From dd76000f5455a5c22dd9ec6fc406f4a08871777f Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Mon, 21 Jul 2025 11:17:52 -0400 Subject: [PATCH 001/104] Use strands logo that looks good in dark & light mode (#505) Similar to strands-agents/sdk-python/pull/475 but using a dedicated github icon. The github icon is the lite logo but copied/renamed to make it dedicated to github --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c31048770..58c647f8d 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@
-      <img src="[previous logo asset]" alt="Strands Agents">
+      <img src="[dedicated GitHub logo asset]" alt="Strands Agents">
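Note: GitHub READMEs usually get a theme-aware logo in one of two ways: pointing the `<img>` at a single asset that renders well on both dark and light backgrounds (the approach taken in the patch above, per its commit message), or wrapping the image in a `<picture>` element whose sources are chosen by `prefers-color-scheme`. A minimal sketch of the `<picture>` pattern follows; the asset paths are illustrative and not taken from this repository.

```html
<picture>
  <!-- shown when the viewer's theme is dark -->
  <source media="(prefers-color-scheme: dark)" srcset="assets/logo-dark.svg">
  <!-- shown when the viewer's theme is light -->
  <source media="(prefers-color-scheme: light)" srcset="assets/logo-light.svg">
  <!-- fallback for renderers that ignore <picture> -->
  <img src="assets/logo-light.svg" alt="Strands Agents">
</picture>
```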
From 24ccb00159c4319cfb5fd3bea4caa5b50c846539 Mon Sep 17 00:00:00 2001 From: Jeremiah Date: Tue, 22 Jul 2025 11:48:26 -0400 Subject: [PATCH 002/104] deps(a2a): address interface changes and bump min version (#515) Co-authored-by: jer --- pyproject.toml | 4 ++-- src/strands/multiagent/a2a/server.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 974ff9d94..765e815ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ writer = [ ] a2a = [ - "a2a-sdk[sql]>=0.2.11,<1.0.0", + "a2a-sdk[sql]>=0.2.16,<1.0.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", @@ -136,7 +136,7 @@ all = [ "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", # a2a - "a2a-sdk[sql]>=0.2.11,<1.0.0", + "a2a-sdk[sql]>=0.2.16,<1.0.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index 568252597..de891499d 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -83,8 +83,8 @@ def public_agent_card(self) -> AgentCard: url=self.http_url, version=self.version, skills=self.agent_skills, - defaultInputModes=["text"], - defaultOutputModes=["text"], + default_input_modes=["text"], + default_output_modes=["text"], capabilities=self.capabilities, ) From 69053420de6695ffc3921481eba04935735f55e3 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 22 Jul 2025 12:45:39 -0400 Subject: [PATCH 003/104] ci: expose STRANDS_TEST_API_KEYS_SECRET_NAME to integration tests (#513) --- .github/workflows/integration-test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index a1d86364a..c347e3805 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -67,6 +67,7 @@ jobs: env: AWS_REGION: us-east-1 AWS_REGION_NAME: us-east-1 # Needed for LiteLLM + STRANDS_TEST_API_KEYS_SECRET_NAME: ${{ secrets.STRANDS_TEST_API_KEYS_SECRET_NAME }} id: tests run: | hatch test tests_integ From 5a7076bfbd01c415fee1c2ec2316c005da9d973a Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Tue, 22 Jul 2025 14:22:17 -0400 Subject: [PATCH 004/104] Don't re-run workflows on un/approvals (#516) These were necessary when we had conditional running but we switched to needing to approve all workflows for non-maintainers, so we no longer need these. 
Co-authored-by: Mackenzie Zastrow --- .github/workflows/pr-and-push.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index 2b2d026f4..b558943dd 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -3,7 +3,7 @@ name: Pull Request and Push Action on: pull_request: # Safer than pull_request_target for untrusted code branches: [ main ] - types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] + types: [opened, synchronize, reopened, ready_for_review] push: branches: [ main ] # Also run on direct pushes to main concurrency: From 9aba0189abf43136a9c3eb477ee5257f735730c9 Mon Sep 17 00:00:00 2001 From: Didier Durand Date: Tue, 22 Jul 2025 21:49:29 +0200 Subject: [PATCH 005/104] Fixing some typos in various texts (#487) --- .../conversation_manager/conversation_manager.py | 2 +- src/strands/multiagent/a2a/executor.py | 2 +- src/strands/session/repository_session_manager.py | 14 +++++++------- src/strands/types/session.py | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 8756a1022..2c1ee7847 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -36,7 +36,7 @@ def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]] Args: state: Previous state of the conversation manager Returns: - Optional list of messages to prepend to the agents messages. By defualt returns None. + Optional list of messages to prepend to the agents messages. By default returns None. """ if state.get("__name__") != self.__class__.__name__: raise ValueError("Invalid conversation manager state.") diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 00eb4764f..d65c64aff 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -4,7 +4,7 @@ to be used as an executor in the A2A protocol. It handles the execution of agent requests and the conversion of Strands Agent streamed responses to A2A events. -The A2A AgentExecutor ensures clients recieve responses for synchronous and +The A2A AgentExecutor ensures clients receive responses for synchronous and streamed requests to the A2AServer. """ diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 487335ac9..18a6ac474 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -32,7 +32,7 @@ def __init__(self, session_id: str, session_repository: SessionRepository, **kwa Args: session_id: ID to use for the session. A new session with this id will be created if it does - not exist in the reposiory yet + not exist in the repository yet session_repository: Underlying session repository to use to store the sessions state. **kwargs: Additional keyword arguments for future extensibility. 
@@ -133,15 +133,15 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: agent.state = AgentState(session_agent.state) # Restore the conversation manager to its previous state, and get the optional prepend messages - prepend_messsages = agent.conversation_manager.restore_from_session( + prepend_messages = agent.conversation_manager.restore_from_session( session_agent.conversation_manager_state ) - if prepend_messsages is None: - prepend_messsages = [] + if prepend_messages is None: + prepend_messages = [] # List the messages currently in the session, using an offset of the messages previously removed - # by the converstaion manager. + # by the conversation manager. session_messages = self.session_repository.list_messages( session_id=self.session_id, agent_id=agent.agent_id, @@ -150,5 +150,5 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: if len(session_messages) > 0: self._latest_agent_message[agent.agent_id] = session_messages[-1] - # Resore the agents messages array including the optional prepend messages - agent.messages = prepend_messsages + [session_message.to_message() for session_message in session_messages] + # Restore the agents messages array including the optional prepend messages + agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages] diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 259ab1171..e51816f74 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -125,7 +125,7 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent": @classmethod def from_dict(cls, env: dict[str, Any]) -> "SessionAgent": - """Initialize a SessionAgent from a dictionary, ignoring keys that are not calss parameters.""" + """Initialize a SessionAgent from a dictionary, ignoring keys that are not class parameters.""" return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) def to_dict(self) -> dict[str, Any]: @@ -144,7 +144,7 @@ class Session: @classmethod def from_dict(cls, env: dict[str, Any]) -> "Session": - """Initialize a Session from a dictionary, ignoring keys that are not calss parameters.""" + """Initialize a Session from a dictionary, ignoring keys that are not class parameters.""" return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) def to_dict(self) -> dict[str, Any]: From 040ba21cdfeb5dfbcdbb6e76ec227356a4429329 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=2E/c=C2=B2?= Date: Tue, 22 Jul 2025 15:52:35 -0400 Subject: [PATCH 006/104] docs(readme): add hot reloading documentation for load_tools_from_directory (#517) - Add new section showcasing Agent(load_tools_from_directory=True) functionality - Document automatic tool loading and reloading from ./tools/ directory - Include practical code example for developers - Improve discoverability of this development feature --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 58c647f8d..62ed54d47 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,17 @@ agent = Agent(tools=[word_count]) response = agent("How many words are in this sentence?") ``` +**Hot Reloading from Directory:** +Enable automatic tool loading and reloading from the `./tools/` directory: + +```python +from strands import Agent + +# Agent will watch ./tools/ directory for changes +agent = Agent(load_tools_from_directory=True) +response = agent("Use any tools you find in the tools directory") +``` + ### MCP Support Seamlessly integrate Model 
Context Protocol (MCP) servers: From 022ec556d7eed2de935deb8293e86f8263056af5 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 22 Jul 2025 16:19:15 -0400 Subject: [PATCH 007/104] ci: enable integ tests for anthropic, cohere, mistral, openai, writer (#510) --- tests_integ/conftest.py | 52 +++++++++++++++++++ tests_integ/models/providers.py | 4 +- .../{conformance.py => test_conformance.py} | 4 +- tests_integ/models/test_model_anthropic.py | 13 +++-- tests_integ/models/test_model_cohere.py | 2 +- 5 files changed, 67 insertions(+), 8 deletions(-) rename tests_integ/models/{conformance.py => test_conformance.py} (81%) diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index f83f0e299..61c2bf9a1 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -1,5 +1,17 @@ +import json +import logging +import os + +import boto3 import pytest +logger = logging.getLogger(__name__) + + +def pytest_sessionstart(session): + _load_api_keys_from_secrets_manager() + + ## Data @@ -28,3 +40,43 @@ async def alist(items): return [item async for item in items] return alist + + +## Models + + +def _load_api_keys_from_secrets_manager(): + """Load API keys as environment variables from AWS Secrets Manager.""" + session = boto3.session.Session() + client = session.client(service_name="secretsmanager") + if "STRANDS_TEST_API_KEYS_SECRET_NAME" in os.environ: + try: + secret_name = os.getenv("STRANDS_TEST_API_KEYS_SECRET_NAME") + response = client.get_secret_value(SecretId=secret_name) + + if "SecretString" in response: + secret = json.loads(response["SecretString"]) + for key, value in secret.items(): + os.environ[f"{key.upper()}_API_KEY"] = str(value) + + except Exception as e: + logger.warning("Error retrieving secret", e) + + """ + Validate that required environment variables are set when running in GitHub Actions. + This prevents tests from being unintentionally skipped due to missing credentials. 
+ """ + if os.environ.get("GITHUB_ACTIONS") != "true": + logger.warning("Tests running outside GitHub Actions, skipping required provider validation") + return + + required_providers = { + "ANTHROPIC_API_KEY", + "COHERE_API_KEY", + "MISTRAL_API_KEY", + "OPENAI_API_KEY", + "WRITER_API_KEY", + } + for provider in required_providers: + if provider not in os.environ or not os.environ[provider]: + raise ValueError(f"Missing required environment variables for {provider}") diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 543f58480..d2ac148d3 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -72,11 +72,11 @@ def __init__(self): bedrock = ProviderInfo(id="bedrock", factory=lambda: BedrockModel()) cohere = ProviderInfo( id="cohere", - environment_variable="CO_API_KEY", + environment_variable="COHERE_API_KEY", factory=lambda: OpenAIModel( client_args={ "base_url": "https://api.cohere.com/compatibility/v1", - "api_key": os.getenv("CO_API_KEY"), + "api_key": os.getenv("COHERE_API_KEY"), }, model_id="command-a-03-2025", params={"stream_options": None}, diff --git a/tests_integ/models/conformance.py b/tests_integ/models/test_conformance.py similarity index 81% rename from tests_integ/models/conformance.py rename to tests_integ/models/test_conformance.py index 262e41e42..d9875bc07 100644 --- a/tests_integ/models/conformance.py +++ b/tests_integ/models/test_conformance.py @@ -1,6 +1,6 @@ import pytest -from strands.types.models import Model +from strands.models import Model from tests_integ.models.providers import ProviderInfo, all_providers @@ -9,7 +9,7 @@ def get_models(): pytest.param( provider_info, id=provider_info.id, # Adds the provider name to the test name - marks=[provider_info.mark], # ignores tests that don't have the requirements + marks=provider_info.mark, # ignores tests that don't have the requirements ) for provider_info in all_providers ] diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py index 2ee5e7f23..62a95d06d 100644 --- a/tests_integ/models/test_model_anthropic.py +++ b/tests_integ/models/test_model_anthropic.py @@ -6,10 +6,17 @@ import strands from strands import Agent from strands.models.anthropic import AnthropicModel -from tests_integ.models import providers -# these tests only run if we have the anthropic api key -pytestmark = providers.anthropic.mark +""" +These tests only run if we have the anthropic api key + +Because of infrequent burst usage, Anthropic tests are unreliable, failing tests with 529s. 
+{'type': 'error', 'error': {'details': None, 'type': 'overloaded_error', 'message': 'Overloaded'}} +https://docs.anthropic.com/en/api/errors#http-errors +""" +pytestmark = pytest.skip( + "Because of infrequent burst usage, Anthropic tests are unreliable, failing with 529s", allow_module_level=True +) @pytest.fixture diff --git a/tests_integ/models/test_model_cohere.py b/tests_integ/models/test_model_cohere.py index 996b0f326..33fb1a8c6 100644 --- a/tests_integ/models/test_model_cohere.py +++ b/tests_integ/models/test_model_cohere.py @@ -16,7 +16,7 @@ def model(): return OpenAIModel( client_args={ "base_url": "https://api.cohere.com/compatibility/v1", - "api_key": os.getenv("CO_API_KEY"), + "api_key": os.getenv("COHERE_API_KEY"), }, model_id="command-a-03-2025", params={"stream_options": None}, From e597e07f06665292c4207270f41eb37cc45fd645 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 23 Jul 2025 11:26:30 -0400 Subject: [PATCH 008/104] Automatically flatten nested tool collections (#508) Fixes issue #50 Customers naturally want to pass nested collections of tools - the above issue has gathered enough data points proving that. --- src/strands/tools/registry.py | 11 +++++++++-- tests/strands/agent/test_agent.py | 19 +++++++++++++++++++ tests/strands/tools/test_registry.py | 27 +++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 9d835d28e..fd395ae77 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -11,7 +11,7 @@ from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional from typing_extensions import TypedDict, cast @@ -54,7 +54,7 @@ def process_tools(self, tools: List[Any]) -> List[str]: """ tool_names = [] - for tool in tools: + def add_tool(tool: Any) -> None: # Case 1: String file path if isinstance(tool, str): # Extract tool name from path @@ -97,9 +97,16 @@ def process_tools(self, tools: List[Any]) -> List[str]: elif isinstance(tool, AgentTool): self.register_tool(tool) tool_names.append(tool.tool_name) + # Case 6: Nested iterable (list, tuple, etc.) 
- add each sub-tool + elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): + for t in tool: + add_tool(t) else: logger.warning("tool=<%s> | unrecognized tool specification", tool) + for a_tool in tools: + add_tool(a_tool) + return tool_names def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d6471a09a..4e310dace 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -231,6 +231,25 @@ def test_agent__init__with_string_model_id(): assert agent.model.config["model_id"] == "nonsense" +def test_agent__init__nested_tools_flattening(tool_decorated, tool_module, tool_imported, tool_registry): + _ = tool_registry + # Nested structure: [tool_decorated, [tool_module, [tool_imported]]] + agent = Agent(tools=[tool_decorated, [tool_module, [tool_imported]]]) + tru_tool_names = sorted(agent.tool_names) + exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"] + assert tru_tool_names == exp_tool_names + + +def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_imported, tool_registry): + _ = tool_registry + # Deeply nested structure + nested_tools = [[[[tool_decorated]], [[tool_module]], tool_imported]] + agent = Agent(tools=nested_tools) + tru_tool_names = sorted(agent.tool_names) + exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"] + assert tru_tool_names == exp_tool_names + + def test_agent__call__( mock_model, system_prompt, diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index ebcba3fb1..66494c987 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -93,3 +93,30 @@ def tool_function_4(d): assert len(tools) == 2 assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools) + + +def test_process_tools_flattens_lists_and_tuples_and_sets(): + def function() -> str: + return "done" + + tool_a = tool(name="tool_a")(function) + tool_b = tool(name="tool_b")(function) + tool_c = tool(name="tool_c")(function) + tool_d = tool(name="tool_d")(function) + tool_e = tool(name="tool_e")(function) + tool_f = tool(name="tool_f")(function) + + registry = ToolRegistry() + + all_tools = [tool_a, (tool_b, tool_c), [{tool_d, tool_e}, [tool_f]]] + + tru_tool_names = sorted(registry.process_tools(all_tools)) + exp_tool_names = [ + "tool_a", + "tool_b", + "tool_c", + "tool_d", + "tool_e", + "tool_f", + ] + assert tru_tool_names == exp_tool_names From 4f4e5efd6730fd05ae4382d5ab1715e7b363be6c Mon Sep 17 00:00:00 2001 From: Jeremiah Date: Wed, 23 Jul 2025 13:44:47 -0400 Subject: [PATCH 009/104] feat(a2a): support mounts for containerized deployments (#524) * feat(a2a): support mounts for containerized deployments * feat(a2a): escape hatch for load balancers which strip paths * feat(a2a): formatting --------- Co-authored-by: jer --- src/strands/multiagent/a2a/server.py | 75 +++- .../session/repository_session_manager.py | 4 +- tests/strands/multiagent/a2a/test_server.py | 343 ++++++++++++++++++ 3 files changed, 412 insertions(+), 10 deletions(-) diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index de891499d..fa7b6b887 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -6,6 +6,7 @@ import logging from typing import Any, Literal +from urllib.parse import urlparse import uvicorn from a2a.server.apps import A2AFastAPIApplication, 
A2AStarletteApplication @@ -31,6 +32,8 @@ def __init__( # AgentCard host: str = "0.0.0.0", port: int = 9000, + http_url: str | None = None, + serve_at_root: bool = False, version: str = "0.0.1", skills: list[AgentSkill] | None = None, ): @@ -40,13 +43,34 @@ def __init__( agent: The Strands Agent to wrap with A2A compatibility. host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0". port: The port to bind the A2A server to. Defaults to 9000. + http_url: The public HTTP URL where this agent will be accessible. If provided, + this overrides the generated URL from host/port and enables automatic + path-based mounting for load balancer scenarios. + Example: "http://my-alb.amazonaws.com/agent1" + serve_at_root: If True, forces the server to serve at root path regardless of + http_url path component. Use this when your load balancer strips path prefixes. + Defaults to False. version: The version of the agent. Defaults to "0.0.1". skills: The list of capabilities or functions the agent can perform. """ self.host = host self.port = port - self.http_url = f"http://{self.host}:{self.port}/" self.version = version + + if http_url: + # Parse the provided URL to extract components for mounting + self.public_base_url, self.mount_path = self._parse_public_url(http_url) + self.http_url = http_url.rstrip("/") + "/" + + # Override mount path if serve_at_root is requested + if serve_at_root: + self.mount_path = "" + else: + # Fall back to constructing the URL from host and port + self.public_base_url = f"http://{host}:{port}" + self.http_url = f"{self.public_base_url}/" + self.mount_path = "" + self.strands_agent = agent self.name = self.strands_agent.name self.description = self.strands_agent.description @@ -58,6 +82,25 @@ def __init__( self._agent_skills = skills logger.info("Strands' integration with A2A is experimental. Be aware of frequent breaking changes.") + def _parse_public_url(self, url: str) -> tuple[str, str]: + """Parse the public URL into base URL and mount path components. + + Args: + url: The full public URL (e.g., "http://my-alb.amazonaws.com/agent1") + + Returns: + tuple: (base_url, mount_path) where base_url is the scheme+netloc + and mount_path is the path component + + Example: + _parse_public_url("http://my-alb.amazonaws.com/agent1") + Returns: ("http://my-alb.amazonaws.com", "/agent1") + """ + parsed = urlparse(url.rstrip("/")) + base_url = f"{parsed.scheme}://{parsed.netloc}" + mount_path = parsed.path if parsed.path != "/" else "" + return base_url, mount_path + @property def public_agent_card(self) -> AgentCard: """Get the public AgentCard for this agent. @@ -119,24 +162,42 @@ def agent_skills(self, skills: list[AgentSkill]) -> None: def to_starlette_app(self) -> Starlette: """Create a Starlette application for serving this agent via HTTP. - This method creates a Starlette application that can be used to serve - the agent via HTTP using the A2A protocol. + Automatically handles path-based mounting if a mount path was derived + from the http_url parameter. Returns: Starlette: A Starlette application configured to serve this agent. 
""" - return A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + a2a_app = A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + if self.mount_path: + # Create parent app and mount the A2A app at the specified path + parent_app = Starlette() + parent_app.mount(self.mount_path, a2a_app) + logger.info("Mounting A2A server at path: %s", self.mount_path) + return parent_app + + return a2a_app def to_fastapi_app(self) -> FastAPI: """Create a FastAPI application for serving this agent via HTTP. - This method creates a FastAPI application that can be used to serve - the agent via HTTP using the A2A protocol. + Automatically handles path-based mounting if a mount path was derived + from the http_url parameter. Returns: FastAPI: A FastAPI application configured to serve this agent. """ - return A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + a2a_app = A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + if self.mount_path: + # Create parent app and mount the A2A app at the specified path + parent_app = FastAPI() + parent_app.mount(self.mount_path, a2a_app) + logger.info("Mounting A2A server at path: %s", self.mount_path) + return parent_app + + return a2a_app def serve( self, diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 18a6ac474..75058b251 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -133,9 +133,7 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: agent.state = AgentState(session_agent.state) # Restore the conversation manager to its previous state, and get the optional prepend messages - prepend_messages = agent.conversation_manager.restore_from_session( - session_agent.conversation_manager_state - ) + prepend_messages = agent.conversation_manager.restore_from_session(session_agent.conversation_manager_state) if prepend_messages is None: prepend_messages = [] diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py index 74f470741..fc76b5f1d 100644 --- a/tests/strands/multiagent/a2a/test_server.py +++ b/tests/strands/multiagent/a2a/test_server.py @@ -509,3 +509,346 @@ def test_serve_handles_general_exception(mock_run, mock_strands_agent, caplog): assert "Strands A2A server encountered exception" in caplog.text assert "Strands A2A server has shutdown" in caplog.text + + +def test_initialization_with_http_url_no_path(mock_strands_agent): + """Test initialization with http_url containing no path.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer( + mock_strands_agent, host="0.0.0.0", port=8080, http_url="http://my-alb.amazonaws.com", skills=[] + ) + + assert a2a_agent.host == "0.0.0.0" + assert a2a_agent.port == 8080 + assert a2a_agent.http_url == "http://my-alb.amazonaws.com/" + assert a2a_agent.public_base_url == "http://my-alb.amazonaws.com" + assert a2a_agent.mount_path == "" + + +def test_initialization_with_http_url_with_path(mock_strands_agent): + """Test initialization with http_url containing a path for mounting.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer( + mock_strands_agent, host="0.0.0.0", port=8080, http_url="http://my-alb.amazonaws.com/agent1", skills=[] + ) + 
+ assert a2a_agent.host == "0.0.0.0" + assert a2a_agent.port == 8080 + assert a2a_agent.http_url == "http://my-alb.amazonaws.com/agent1/" + assert a2a_agent.public_base_url == "http://my-alb.amazonaws.com" + assert a2a_agent.mount_path == "/agent1" + + +def test_initialization_with_https_url(mock_strands_agent): + """Test initialization with HTTPS URL.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="https://my-alb.amazonaws.com/secure-agent", skills=[]) + + assert a2a_agent.http_url == "https://my-alb.amazonaws.com/secure-agent/" + assert a2a_agent.public_base_url == "https://my-alb.amazonaws.com" + assert a2a_agent.mount_path == "/secure-agent" + + +def test_initialization_with_http_url_with_port(mock_strands_agent): + """Test initialization with http_url containing explicit port.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://my-server.com:8080/api/agent", skills=[]) + + assert a2a_agent.http_url == "http://my-server.com:8080/api/agent/" + assert a2a_agent.public_base_url == "http://my-server.com:8080" + assert a2a_agent.mount_path == "/api/agent" + + +def test_parse_public_url_method(mock_strands_agent): + """Test the _parse_public_url method directly.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + # Test various URL formats + base_url, mount_path = a2a_agent._parse_public_url("http://example.com/path") + assert base_url == "http://example.com" + assert mount_path == "/path" + + base_url, mount_path = a2a_agent._parse_public_url("https://example.com:443/deep/path") + assert base_url == "https://example.com:443" + assert mount_path == "/deep/path" + + base_url, mount_path = a2a_agent._parse_public_url("http://example.com/") + assert base_url == "http://example.com" + assert mount_path == "" + + base_url, mount_path = a2a_agent._parse_public_url("http://example.com") + assert base_url == "http://example.com" + assert mount_path == "" + + +def test_public_agent_card_with_http_url(mock_strands_agent): + """Test that public_agent_card uses the http_url when provided.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="https://my-alb.amazonaws.com/agent1", skills=[]) + + card = a2a_agent.public_agent_card + + assert isinstance(card, AgentCard) + assert card.url == "https://my-alb.amazonaws.com/agent1/" + assert card.name == "Test Agent" + assert card.description == "A test agent for unit testing" + + +def test_to_starlette_app_with_mounting(mock_strands_agent): + """Test that to_starlette_app creates mounted app when mount_path exists.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[]) + + app = a2a_agent.to_starlette_app() + + assert isinstance(app, Starlette) + + +def test_to_starlette_app_without_mounting(mock_strands_agent): + """Test that to_starlette_app creates regular app when no mount_path.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com", skills=[]) + + app = a2a_agent.to_starlette_app() + + assert isinstance(app, Starlette) + + +def test_to_fastapi_app_with_mounting(mock_strands_agent): + """Test that to_fastapi_app 
creates mounted app when mount_path exists.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[]) + + app = a2a_agent.to_fastapi_app() + + assert isinstance(app, FastAPI) + + +def test_to_fastapi_app_without_mounting(mock_strands_agent): + """Test that to_fastapi_app creates regular app when no mount_path.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com", skills=[]) + + app = a2a_agent.to_fastapi_app() + + assert isinstance(app, FastAPI) + + +def test_backwards_compatibility_without_http_url(mock_strands_agent): + """Test that the old behavior is preserved when http_url is not provided.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, host="localhost", port=9000, skills=[]) + + # Should behave exactly like before + assert a2a_agent.host == "localhost" + assert a2a_agent.port == 9000 + assert a2a_agent.http_url == "http://localhost:9000/" + assert a2a_agent.public_base_url == "http://localhost:9000" + assert a2a_agent.mount_path == "" + + # Agent card should use the traditional URL + card = a2a_agent.public_agent_card + assert card.url == "http://localhost:9000/" + + +def test_mount_path_logging(mock_strands_agent, caplog): + """Test that mounting logs the correct message.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/test-agent", skills=[]) + + # Test Starlette app mounting logs + caplog.clear() + a2a_agent.to_starlette_app() + assert "Mounting A2A server at path: /test-agent" in caplog.text + + # Test FastAPI app mounting logs + caplog.clear() + a2a_agent.to_fastapi_app() + assert "Mounting A2A server at path: /test-agent" in caplog.text + + +def test_http_url_trailing_slash_handling(mock_strands_agent): + """Test that trailing slashes in http_url are handled correctly.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Test with trailing slash + a2a_agent1 = A2AServer(mock_strands_agent, http_url="http://example.com/agent1/", skills=[]) + + # Test without trailing slash + a2a_agent2 = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[]) + + # Both should result in the same normalized URL + assert a2a_agent1.http_url == "http://example.com/agent1/" + assert a2a_agent2.http_url == "http://example.com/agent1/" + assert a2a_agent1.mount_path == "/agent1" + assert a2a_agent2.mount_path == "/agent1" + + +def test_serve_at_root_default_behavior(mock_strands_agent): + """Test default behavior extracts mount path from http_url.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + server = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[]) + + assert server.mount_path == "/agent1" + assert server.http_url == "http://my-alb.com/agent1/" + + +def test_serve_at_root_overrides_mounting(mock_strands_agent): + """Test serve_at_root=True overrides automatic path mounting.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + server = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[]) + + assert server.mount_path == "" # Should be empty despite path in URL + assert server.http_url == "http://my-alb.com/agent1/" # Public URL 
unchanged + + +def test_serve_at_root_with_no_path(mock_strands_agent): + """Test serve_at_root=True when no path in URL (redundant but valid).""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + server = A2AServer(mock_strands_agent, host="localhost", port=8080, serve_at_root=True, skills=[]) + + assert server.mount_path == "" + assert server.http_url == "http://localhost:8080/" + + +def test_serve_at_root_complex_path(mock_strands_agent): + """Test serve_at_root=True with complex nested paths.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + server = A2AServer( + mock_strands_agent, http_url="http://api.example.com/v1/agents/my-agent", serve_at_root=True, skills=[] + ) + + assert server.mount_path == "" + assert server.http_url == "http://api.example.com/v1/agents/my-agent/" + + +def test_serve_at_root_fastapi_mounting_behavior(mock_strands_agent): + """Test FastAPI mounting behavior with serve_at_root.""" + from fastapi.testclient import TestClient + + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Normal mounting + server_mounted = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[]) + app_mounted = server_mounted.to_fastapi_app() + client_mounted = TestClient(app_mounted) + + # Should work at mounted path + response = client_mounted.get("/agent1/.well-known/agent.json") + assert response.status_code == 200 + + # Should not work at root + response = client_mounted.get("/.well-known/agent.json") + assert response.status_code == 404 + + +def test_serve_at_root_fastapi_root_behavior(mock_strands_agent): + """Test FastAPI serve_at_root behavior.""" + from fastapi.testclient import TestClient + + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Serve at root + server_root = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[]) + app_root = server_root.to_fastapi_app() + client_root = TestClient(app_root) + + # Should work at root + response = client_root.get("/.well-known/agent.json") + assert response.status_code == 200 + + # Should not work at mounted path (since we're serving at root) + response = client_root.get("/agent1/.well-known/agent.json") + assert response.status_code == 404 + + +def test_serve_at_root_starlette_behavior(mock_strands_agent): + """Test Starlette serve_at_root behavior.""" + from starlette.testclient import TestClient + + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Normal mounting + server_mounted = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[]) + app_mounted = server_mounted.to_starlette_app() + client_mounted = TestClient(app_mounted) + + # Should work at mounted path + response = client_mounted.get("/agent1/.well-known/agent.json") + assert response.status_code == 200 + + # Serve at root + server_root = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[]) + app_root = server_root.to_starlette_app() + client_root = TestClient(app_root) + + # Should work at root + response = client_root.get("/.well-known/agent.json") + assert response.status_code == 200 + + +def test_serve_at_root_alb_scenarios(mock_strands_agent): + """Test common ALB deployment scenarios.""" + from fastapi.testclient import TestClient + + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # ALB with path preservation + server_preserved = A2AServer(mock_strands_agent, 
http_url="http://my-alb.amazonaws.com/agent1", skills=[]) + app_preserved = server_preserved.to_fastapi_app() + client_preserved = TestClient(app_preserved) + + # Container receives /agent1/.well-known/agent.json + response = client_preserved.get("/agent1/.well-known/agent.json") + assert response.status_code == 200 + agent_data = response.json() + assert agent_data["url"] == "http://my-alb.amazonaws.com/agent1/" + + # ALB with path stripping + server_stripped = A2AServer( + mock_strands_agent, http_url="http://my-alb.amazonaws.com/agent1", serve_at_root=True, skills=[] + ) + app_stripped = server_stripped.to_fastapi_app() + client_stripped = TestClient(app_stripped) + + # Container receives /.well-known/agent.json (path stripped by ALB) + response = client_stripped.get("/.well-known/agent.json") + assert response.status_code == 200 + agent_data = response.json() + assert agent_data["url"] == "http://my-alb.amazonaws.com/agent1/" + + +def test_serve_at_root_edge_cases(mock_strands_agent): + """Test edge cases for serve_at_root parameter.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Root path in URL + server1 = A2AServer(mock_strands_agent, http_url="http://example.com/", skills=[]) + assert server1.mount_path == "" + + # serve_at_root should be redundant but not cause issues + server2 = A2AServer(mock_strands_agent, http_url="http://example.com/", serve_at_root=True, skills=[]) + assert server2.mount_path == "" + + # Multiple nested paths + server3 = A2AServer( + mock_strands_agent, http_url="http://api.example.com/v1/agents/team1/agent1", serve_at_root=True, skills=[] + ) + assert server3.mount_path == "" + assert server3.http_url == "http://api.example.com/v1/agents/team1/agent1/" From b30e7e6e41e7a2dce70d74e8c1753503959f3619 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Wed, 23 Jul 2025 15:20:28 -0400 Subject: [PATCH 010/104] fix: include agent trace into tool for agent as tools (#526) --- src/strands/telemetry/tracer.py | 2 +- src/strands/tools/executor.py | 37 ++++++++++++++++----------------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index eebffef29..802865189 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -273,7 +273,7 @@ def end_model_invoke_span( self._end_span(span, attributes, error) - def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Optional[Span]: + def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Span: """Start a new span for a tool call. Args: diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index 1214fa608..d90f9a5aa 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -5,7 +5,7 @@ import time from typing import Any, Optional, cast -from opentelemetry import trace +from opentelemetry import trace as trace_api from ..telemetry.metrics import EventLoopMetrics, Trace from ..telemetry.tracer import get_tracer @@ -23,7 +23,7 @@ async def run_tools( invalid_tool_use_ids: list[str], tool_results: list[ToolResult], cycle_trace: Trace, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[trace_api.Span] = None, ) -> ToolGenerator: """Execute tools concurrently. 
@@ -53,24 +53,23 @@ async def work( tool_name = tool_use["name"] tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) tool_start_time = time.time() + with trace_api.use_span(tool_call_span): + try: + async for event in handler(tool_use): + worker_queue.put_nowait((worker_id, event)) + await worker_event.wait() + worker_event.clear() + + result = cast(ToolResult, event) + finally: + worker_queue.put_nowait((worker_id, stop_event)) + + tool_success = result.get("status") == "success" + tool_duration = time.time() - tool_start_time + message = Message(role="user", content=[{"toolResult": result}]) + event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) + cycle_trace.add_child(tool_trace) - try: - async for event in handler(tool_use): - worker_queue.put_nowait((worker_id, event)) - await worker_event.wait() - worker_event.clear() - - result = cast(ToolResult, event) - finally: - worker_queue.put_nowait((worker_id, stop_event)) - - tool_success = result.get("status") == "success" - tool_duration = time.time() - tool_start_time - message = Message(role="user", content=[{"toolResult": result}]) - event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) - cycle_trace.add_child(tool_trace) - - if tool_call_span: tracer.end_tool_call_span(tool_call_span, result) return result From 8c5562575f8c6c26c2b2a18591d1d5926a96514a Mon Sep 17 00:00:00 2001 From: Davide Gallitelli Date: Mon, 28 Jul 2025 13:34:04 +0200 Subject: [PATCH 011/104] Support for Amazon SageMaker AI endpoints as Model Provider (#176) --- pyproject.toml | 18 +- src/strands/models/sagemaker.py | 600 +++++++++++++++++++++ tests/strands/models/test_sagemaker.py | 574 ++++++++++++++++++++ tests_integ/models/test_model_sagemaker.py | 76 +++ 4 files changed, 1262 insertions(+), 6 deletions(-) create mode 100644 src/strands/models/sagemaker.py create mode 100644 tests/strands/models/test_sagemaker.py create mode 100644 tests_integ/models/test_model_sagemaker.py diff --git a/pyproject.toml b/pyproject.toml index 765e815ef..745c80e0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,8 +89,14 @@ writer = [ "writer-sdk>=2.2.0,<3.0.0" ] +sagemaker = [ + "boto3>=1.26.0,<2.0.0", + "botocore>=1.29.0,<2.0.0", + "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0" +] + a2a = [ - "a2a-sdk[sql]>=0.2.16,<1.0.0", + "a2a-sdk[sql]>=0.2.11,<1.0.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", @@ -136,7 +142,7 @@ all = [ "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", # a2a - "a2a-sdk[sql]>=0.2.16,<1.0.0", + "a2a-sdk[sql]>=0.2.11,<1.0.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", @@ -148,7 +154,7 @@ all = [ source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -171,7 +177,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -187,7 +193,7 @@ extra-args = [ [tool.hatch.envs.dev] dev-mode = true -features = 
["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a"] +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a", "sagemaker"] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] @@ -315,4 +321,4 @@ style = [ ["instruction", ""], ["text", ""], ["disabled", "fg:#858585 italic"] -] +] \ No newline at end of file diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py new file mode 100644 index 000000000..bb2db45a2 --- /dev/null +++ b/src/strands/models/sagemaker.py @@ -0,0 +1,600 @@ +"""Amazon SageMaker model provider.""" + +import json +import logging +import os +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast + +import boto3 +from botocore.config import Config as BotocoreConfig +from mypy_boto3_sagemaker_runtime import SageMakerRuntimeClient +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import ContentBlock, Messages +from ..types.streaming import StreamEvent +from ..types.tools import ToolResult, ToolSpec +from .openai import OpenAIModel + +T = TypeVar("T", bound=BaseModel) + +logger = logging.getLogger(__name__) + + +@dataclass +class UsageMetadata: + """Usage metadata for the model. + + Attributes: + total_tokens: Total number of tokens used in the request + completion_tokens: Number of tokens used in the completion + prompt_tokens: Number of tokens used in the prompt + prompt_tokens_details: Additional information about the prompt tokens (optional) + """ + + total_tokens: int + completion_tokens: int + prompt_tokens: int + prompt_tokens_details: Optional[int] = 0 + + +@dataclass +class FunctionCall: + """Function call for the model. + + Attributes: + name: Name of the function to call + arguments: Arguments to pass to the function + """ + + name: Union[str, dict[Any, Any]] + arguments: Union[str, dict[Any, Any]] + + def __init__(self, **kwargs: dict[str, str]): + """Initialize function call. + + Args: + **kwargs: Keyword arguments for the function call. + """ + self.name = kwargs.get("name", "") + self.arguments = kwargs.get("arguments", "") + + +@dataclass +class ToolCall: + """Tool call for the model object. + + Attributes: + id: Tool call ID + type: Tool call type + function: Tool call function + """ + + id: str + type: Literal["function"] + function: FunctionCall + + def __init__(self, **kwargs: dict): + """Initialize tool call object. + + Args: + **kwargs: Keyword arguments for the tool call. + """ + self.id = str(kwargs.get("id", "")) + self.type = "function" + self.function = FunctionCall(**kwargs.get("function", {"name": "", "arguments": ""})) + + +class SageMakerAIModel(OpenAIModel): + """Amazon SageMaker model provider implementation.""" + + client: SageMakerRuntimeClient # type: ignore[assignment] + + class SageMakerAIPayloadSchema(TypedDict, total=False): + """Payload schema for the Amazon SageMaker AI model. 
+ + Attributes: + max_tokens: Maximum number of tokens to generate in the completion + stream: Whether to stream the response + temperature: Sampling temperature to use for the model (optional) + top_p: Nucleus sampling parameter (optional) + top_k: Top-k sampling parameter (optional) + stop: List of stop sequences to use for the model (optional) + tool_results_as_user_messages: Convert tool result to user messages (optional) + additional_args: Additional request parameters, as supported by https://bit.ly/djl-lmi-request-schema + """ + + max_tokens: int + stream: bool + temperature: Optional[float] + top_p: Optional[float] + top_k: Optional[int] + stop: Optional[list[str]] + tool_results_as_user_messages: Optional[bool] + additional_args: Optional[dict[str, Any]] + + class SageMakerAIEndpointConfig(TypedDict, total=False): + """Configuration options for SageMaker models. + + Attributes: + endpoint_name: The name of the SageMaker endpoint to invoke + inference_component_name: The name of the inference component to use + + additional_args: Other request parameters, as supported by https://bit.ly/sagemaker-invoke-endpoint-params + """ + + endpoint_name: str + region_name: str + inference_component_name: Union[str, None] + target_model: Union[Optional[str], None] + target_variant: Union[Optional[str], None] + additional_args: Optional[dict[str, Any]] + + def __init__( + self, + endpoint_config: SageMakerAIEndpointConfig, + payload_config: SageMakerAIPayloadSchema, + boto_session: Optional[boto3.Session] = None, + boto_client_config: Optional[BotocoreConfig] = None, + ): + """Initialize provider instance. + + Args: + endpoint_config: Endpoint configuration for SageMaker. + payload_config: Payload configuration for the model. + boto_session: Boto Session to use when calling the SageMaker Runtime. + boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client. + """ + payload_config.setdefault("stream", True) + payload_config.setdefault("tool_results_as_user_messages", False) + self.endpoint_config = dict(endpoint_config) + self.payload_config = dict(payload_config) + logger.debug( + "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config + ) + + region = self.endpoint_config.get("region_name") or os.getenv("AWS_REGION") or "us-west-2" + session = boto_session or boto3.Session(region_name=str(region)) + + # Add strands-agents to the request user agent + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + + # Append 'strands-agents' to existing user_agent_extra or set it if not present + new_user_agent = f"{existing_user_agent} strands-agents" if existing_user_agent else "strands-agents" + + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + + self.client = session.client( + service_name="sagemaker-runtime", + config=client_config, + ) + + @override + def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) -> None: # type: ignore[override] + """Update the Amazon SageMaker model configuration with the provided arguments. + + Args: + **endpoint_config: Configuration overrides. + """ + self.endpoint_config.update(endpoint_config) + + @override + def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: ignore[override] + """Get the Amazon SageMaker model configuration. 
+ + Returns: + The Amazon SageMaker model configuration. + """ + return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config) + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format an Amazon SageMaker chat streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An Amazon SageMaker chat streaming request. + """ + formatted_messages = self.format_request_messages(messages, system_prompt) + + payload = { + "messages": formatted_messages, + "tools": [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs or [] + ], + # Add payload configuration parameters + **{ + k: v + for k, v in self.payload_config.items() + if k not in ["additional_args", "tool_results_as_user_messages"] + }, + } + + # Remove tools and tool_choice if tools = [] + if not payload["tools"]: + payload.pop("tools") + payload.pop("tool_choice", None) + else: + # Ensure the model can use tools when available + payload["tool_choice"] = "auto" + + for message in payload["messages"]: # type: ignore + # Assistant message must have either content or tool_calls, but not both + if message.get("role", "") == "assistant" and message.get("tool_calls", []) != []: + message.pop("content", None) + if message.get("role") == "tool" and self.payload_config.get("tool_results_as_user_messages", False): + # Convert tool message to user message + tool_call_id = message.get("tool_call_id", "ABCDEF") + content = message.get("content", "") + message = {"role": "user", "content": f"Tool call ID '{tool_call_id}' returned: {content}"} + # Cannot have both reasoning_text and text - if "text", content becomes an array of content["text"] + for c in message.get("content", []): + if "text" in c: + message["content"] = [c] + break + # Cast message content to string for TGI compatibility + # message["content"] = str(message.get("content", "")) + + logger.info("payload=<%s>", json.dumps(payload, indent=2)) + # Format the request according to the SageMaker Runtime API requirements + request = { + "EndpointName": self.endpoint_config["endpoint_name"], + "Body": json.dumps(payload), + "ContentType": "application/json", + "Accept": "application/json", + } + + # Add optional SageMaker parameters if provided + if self.endpoint_config.get("inference_component_name"): + request["InferenceComponentName"] = self.endpoint_config["inference_component_name"] + if self.endpoint_config.get("target_model"): + request["TargetModel"] = self.endpoint_config["target_model"] + if self.endpoint_config.get("target_variant"): + request["TargetVariant"] = self.endpoint_config["target_variant"] + + # Add additional args if provided + if self.endpoint_config.get("additional_args"): + request.update(self.endpoint_config["additional_args"].__dict__) + + print(json.dumps(request["Body"], indent=2)) + + return request + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the SageMaker model. 
+ + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. + """ + logger.debug("formatting request") + request = self.format_request(messages, tool_specs, system_prompt) + logger.debug("formatted request=<%s>", request) + + logger.debug("invoking model") + try: + if self.payload_config.get("stream", True): + response = self.client.invoke_endpoint_with_response_stream(**request) + + # Message start + yield self.format_chunk({"chunk_type": "message_start"}) + + # Parse the content + finish_reason = "" + partial_content = "" + tool_calls: dict[int, list[Any]] = {} + has_text_content = False + text_content_started = False + reasoning_content_started = False + + for event in response["Body"]: + chunk = event["PayloadPart"]["Bytes"].decode("utf-8") + partial_content += chunk[6:] if chunk.startswith("data: ") else chunk # TGI fix + logger.info("chunk=<%s>", partial_content) + try: + content = json.loads(partial_content) + partial_content = "" + choice = content["choices"][0] + logger.info("choice=<%s>", json.dumps(choice, indent=2)) + + # Handle text content + if choice["delta"].get("content", None): + if not text_content_started: + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + text_content_started = True + has_text_content = True + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "text", + "data": choice["delta"]["content"], + } + ) + + # Handle reasoning content + if choice["delta"].get("reasoning_content", None): + if not reasoning_content_started: + yield self.format_chunk( + {"chunk_type": "content_start", "data_type": "reasoning_content"} + ) + reasoning_content_started = True + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": choice["delta"]["reasoning_content"], + } + ) + + # Handle tool calls + generated_tool_calls = choice["delta"].get("tool_calls", []) + if not isinstance(generated_tool_calls, list): + generated_tool_calls = [generated_tool_calls] + for tool_call in generated_tool_calls: + tool_calls.setdefault(tool_call["index"], []).append(tool_call) + + if choice["finish_reason"] is not None: + finish_reason = choice["finish_reason"] + break + + if choice.get("usage", None): + yield self.format_chunk( + {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])} + ) + + except json.JSONDecodeError: + # Continue accumulating content until we have valid JSON + continue + + # Close reasoning content if it was started + if reasoning_content_started: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"}) + + # Close text content if it was started + if text_content_started: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + + # Handle tool calling + logger.info("tool_calls=<%s>", json.dumps(tool_calls, indent=2)) + for tool_deltas in tool_calls.values(): + if not tool_deltas[0]["function"].get("name", None): + raise Exception("The model did not provide a tool name.") + yield self.format_chunk( + {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_deltas[0])} + ) + for tool_delta in tool_deltas: + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "tool", "data": 
ToolCall(**tool_delta)} + ) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + # If no content was generated at all, ensure we have empty text content + if not has_text_content and not tool_calls: + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + + # Message close + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + else: + # Not all SageMaker AI models support streaming! + response = self.client.invoke_endpoint(**request) # type: ignore[assignment] + final_response_json = json.loads(response["Body"].read().decode("utf-8")) # type: ignore[attr-defined] + logger.info("response=<%s>", json.dumps(final_response_json, indent=2)) + + # Obtain the key elements from the response + message = final_response_json["choices"][0]["message"] + message_stop_reason = final_response_json["choices"][0]["finish_reason"] + + # Message start + yield self.format_chunk({"chunk_type": "message_start"}) + + # Handle text + if message.get("content", ""): + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": message["content"]} + ) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + + # Handle reasoning content + if message.get("reasoning_content", None): + yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"}) + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": message["reasoning_content"], + } + ) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"}) + + # Handle the tool calling, if any + if message.get("tool_calls", None) or message_stop_reason == "tool_calls": + if not isinstance(message["tool_calls"], list): + message["tool_calls"] = [message["tool_calls"]] + for tool_call in message["tool_calls"]: + # if arguments of tool_call is not str, cast it + if not isinstance(tool_call["function"]["arguments"], str): + tool_call["function"]["arguments"] = json.dumps(tool_call["function"]["arguments"]) + yield self.format_chunk( + {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_call)} + ) + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_call)} + ) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + message_stop_reason = "tool_calls" + + # Message close + yield self.format_chunk({"chunk_type": "message_stop", "data": message_stop_reason}) + # Handle usage metadata + if final_response_json.get("usage", None): + yield self.format_chunk( + {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage", None))} + ) + except ( + self.client.exceptions.InternalFailure, + self.client.exceptions.ServiceUnavailable, + self.client.exceptions.ValidationError, + self.client.exceptions.ModelError, + self.client.exceptions.InternalDependencyException, + self.client.exceptions.ModelNotReadyException, + ) as e: + logger.error("SageMaker error: %s", str(e)) + raise e + + logger.debug("finished streaming response from model") + + @override + @classmethod + def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + """Format a SageMaker compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. 
+ + Returns: + SageMaker compatible tool message with content as a string. + """ + # Convert content blocks to a simple string for SageMaker compatibility + content_parts = [] + for content in tool_result["content"]: + if "json" in content: + content_parts.append(json.dumps(content["json"])) + elif "text" in content: + content_parts.append(content["text"]) + else: + # Handle other content types by converting to string + content_parts.append(str(content)) + + content_string = " ".join(content_parts) + + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": content_string, # String instead of list + } + + @override + @classmethod + def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + """Format a content block. + + Args: + content: Message content. + + Returns: + Formatted content block. + + Raises: + TypeError: If the content block type cannot be converted to a SageMaker-compatible format. + """ + # if "text" in content and not isinstance(content["text"], str): + # return {"type": "text", "text": str(content["text"])} + + if "reasoningContent" in content and content["reasoningContent"]: + return { + "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""), + "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""), + "type": "thinking", + } + elif not content.get("reasoningContent", None): + content.pop("reasoningContent", None) + + if "video" in content: + return { + "type": "video_url", + "video_url": { + "detail": "auto", + "url": content["video"]["source"]["bytes"], + }, + } + + return super().format_request_message_content(content) + + @override + async def structured_output( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + """Get structured output from the model. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. 
+ """ + # Format the request for structured output + request = self.format_request(prompt, system_prompt=system_prompt) + + # Parse the payload to add response format + payload = json.loads(request["Body"]) + payload["response_format"] = { + "type": "json_schema", + "json_schema": {"name": output_model.__name__, "schema": output_model.model_json_schema(), "strict": True}, + } + request["Body"] = json.dumps(payload) + + try: + # Use non-streaming mode for structured output + response = self.client.invoke_endpoint(**request) + final_response_json = json.loads(response["Body"].read().decode("utf-8")) + + # Extract the structured content + message = final_response_json["choices"][0]["message"] + + if message.get("content"): + try: + # Parse the JSON content and create the output model instance + content_data = json.loads(message["content"]) + parsed_output = output_model(**content_data) + yield {"output": parsed_output} + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse structured output: {e}") from e + else: + raise ValueError("No content found in SageMaker response") + + except ( + self.client.exceptions.InternalFailure, + self.client.exceptions.ServiceUnavailable, + self.client.exceptions.ValidationError, + self.client.exceptions.ModelError, + self.client.exceptions.InternalDependencyException, + self.client.exceptions.ModelNotReadyException, + ) as e: + logger.error("SageMaker structured output error: %s", str(e)) + raise ValueError(f"SageMaker structured output error: {str(e)}") from e diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py new file mode 100644 index 000000000..ba395b2d6 --- /dev/null +++ b/tests/strands/models/test_sagemaker.py @@ -0,0 +1,574 @@ +"""Tests for the Amazon SageMaker model provider.""" + +import json +import unittest.mock +from typing import Any, Dict, List + +import boto3 +import pytest +from botocore.config import Config as BotocoreConfig + +from strands.models.sagemaker import ( + FunctionCall, + SageMakerAIModel, + ToolCall, + UsageMetadata, +) +from strands.types.content import Messages +from strands.types.tools import ToolSpec + + +@pytest.fixture +def boto_session(): + """Mock boto3 session.""" + with unittest.mock.patch.object(boto3, "Session") as mock_session: + yield mock_session.return_value + + +@pytest.fixture +def sagemaker_client(boto_session): + """Mock SageMaker runtime client.""" + return boto_session.client.return_value + + +@pytest.fixture +def endpoint_config() -> Dict[str, Any]: + """Default endpoint configuration for tests.""" + return { + "endpoint_name": "test-endpoint", + "inference_component_name": "test-component", + "region_name": "us-east-1", + } + + +@pytest.fixture +def payload_config() -> Dict[str, Any]: + """Default payload configuration for tests.""" + return { + "max_tokens": 1024, + "temperature": 0.7, + "stream": True, + } + + +@pytest.fixture +def model(boto_session, endpoint_config, payload_config): + """SageMaker model instance with mocked boto session.""" + return SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session) + + +@pytest.fixture +def messages() -> Messages: + """Sample messages for testing.""" + return [{"role": "user", "content": [{"text": "What is the capital of France?"}]}] + + +@pytest.fixture +def tool_specs() -> List[ToolSpec]: + """Sample tool specifications for testing.""" + return [ + { + "name": "get_weather", + "description": "Get the weather for a location", + 
"inputSchema": { + "json": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + } + }, + } + ] + + +@pytest.fixture +def system_prompt() -> str: + """Sample system prompt for testing.""" + return "You are a helpful assistant." + + +class TestSageMakerAIModel: + """Test suite for SageMakerAIModel.""" + + def test_init_default(self, boto_session): + """Test initialization with default parameters.""" + endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} + payload_config = {"max_tokens": 1024} + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + assert model.endpoint_config["endpoint_name"] == "test-endpoint" + assert model.payload_config.get("stream", True) is True + + boto_session.client.assert_called_once_with( + service_name="sagemaker-runtime", + config=unittest.mock.ANY, + ) + + def test_init_with_all_params(self, boto_session): + """Test initialization with all parameters.""" + endpoint_config = { + "endpoint_name": "test-endpoint", + "inference_component_name": "test-component", + "region_name": "us-west-2", + } + payload_config = { + "stream": False, + "max_tokens": 1024, + "temperature": 0.7, + } + client_config = BotocoreConfig(user_agent_extra="test-agent") + + model = SageMakerAIModel( + endpoint_config=endpoint_config, + payload_config=payload_config, + boto_session=boto_session, + boto_client_config=client_config, + ) + + assert model.endpoint_config["endpoint_name"] == "test-endpoint" + assert model.endpoint_config["inference_component_name"] == "test-component" + assert model.payload_config["stream"] is False + assert model.payload_config["max_tokens"] == 1024 + assert model.payload_config["temperature"] == 0.7 + + boto_session.client.assert_called_once_with( + service_name="sagemaker-runtime", + config=unittest.mock.ANY, + ) + + def test_init_with_client_config(self, boto_session): + """Test initialization with client configuration.""" + endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} + payload_config = {"max_tokens": 1024} + client_config = BotocoreConfig(user_agent_extra="test-agent") + + SageMakerAIModel( + endpoint_config=endpoint_config, + payload_config=payload_config, + boto_session=boto_session, + boto_client_config=client_config, + ) + + # Verify client was created with a config that includes our user agent + boto_session.client.assert_called_once_with( + service_name="sagemaker-runtime", + config=unittest.mock.ANY, + ) + + # Get the actual config passed to client + actual_config = boto_session.client.call_args[1]["config"] + assert "strands-agents" in actual_config.user_agent_extra + assert "test-agent" in actual_config.user_agent_extra + + def test_update_config(self, model): + """Test updating model configuration.""" + new_config = {"target_model": "new-model", "target_variant": "new-variant"} + model.update_config(**new_config) + + assert model.endpoint_config["target_model"] == "new-model" + assert model.endpoint_config["target_variant"] == "new-variant" + # Original values should be preserved + assert model.endpoint_config["endpoint_name"] == "test-endpoint" + assert model.endpoint_config["inference_component_name"] == "test-component" + + def test_get_config(self, model, endpoint_config): + """Test getting model configuration.""" + config = model.get_config() + assert config == model.endpoint_config + assert isinstance(config, dict) + + # def 
test_format_request_messages_with_system_prompt(self, model): + # """Test formatting request messages with system prompt.""" + # messages = [{"role": "user", "content": "Hello"}] + # system_prompt = "You are a helpful assistant." + + # formatted_messages = model.format_request_messages(messages, system_prompt) + + # assert len(formatted_messages) == 2 + # assert formatted_messages[0]["role"] == "system" + # assert formatted_messages[0]["content"] == system_prompt + # assert formatted_messages[1]["role"] == "user" + # assert formatted_messages[1]["content"] == "Hello" + + # def test_format_request_messages_with_tool_calls(self, model): + # """Test formatting request messages with tool calls.""" + # messages = [ + # {"role": "user", "content": "Hello"}, + # { + # "role": "assistant", + # "content": None, + # "tool_calls": [{"id": "123", "type": "function", "function": {"name": "test", "arguments": "{}"}}], + # }, + # ] + + # formatted_messages = model.format_request_messages(messages, None) + + # assert len(formatted_messages) == 2 + # assert formatted_messages[0]["role"] == "user" + # assert formatted_messages[1]["role"] == "assistant" + # assert "content" not in formatted_messages[1] + # assert "tool_calls" in formatted_messages[1] + + # def test_format_request(self, model, messages, tool_specs, system_prompt): + # """Test formatting a request with all parameters.""" + # request = model.format_request(messages, tool_specs, system_prompt) + + # assert request["EndpointName"] == "test-endpoint" + # assert request["InferenceComponentName"] == "test-component" + # assert request["ContentType"] == "application/json" + # assert request["Accept"] == "application/json" + + # payload = json.loads(request["Body"]) + # assert "messages" in payload + # assert len(payload["messages"]) > 0 + # assert "tools" in payload + # assert len(payload["tools"]) == 1 + # assert payload["tools"][0]["type"] == "function" + # assert payload["tools"][0]["function"]["name"] == "get_weather" + # assert payload["max_tokens"] == 1024 + # assert payload["temperature"] == 0.7 + # assert payload["stream"] is True + + # def test_format_request_without_tools(self, model, messages, system_prompt): + # """Test formatting a request without tools.""" + # request = model.format_request(messages, None, system_prompt) + + # payload = json.loads(request["Body"]) + # assert "tools" in payload + # assert payload["tools"] == [] + + @pytest.mark.asyncio + async def test_stream_with_streaming_enabled(self, sagemaker_client, model, messages): + """Test streaming response with streaming enabled.""" + # Mock the response from SageMaker + mock_response = { + "Body": [ + { + "PayloadPart": { + "Bytes": json.dumps( + { + "choices": [ + { + "delta": {"content": "Paris is the capital of France."}, + "finish_reason": None, + } + ] + } + ).encode("utf-8") + } + }, + { + "PayloadPart": { + "Bytes": json.dumps( + { + "choices": [ + { + "delta": {"content": " It is known for the Eiffel Tower."}, + "finish_reason": "stop", + } + ] + } + ).encode("utf-8") + } + }, + ] + } + sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response + + response = [chunk async for chunk in model.stream(messages)] + + assert len(response) >= 5 + assert response[0] == {"messageStart": {"role": "assistant"}} + + # Find content events + content_start = next((e for e in response if "contentBlockStart" in e), None) + content_delta = next((e for e in response if "contentBlockDelta" in e), None) + content_stop = next((e for e in response if 
"contentBlockStop" in e), None) + message_stop = next((e for e in response if "messageStop" in e), None) + + assert content_start is not None + assert content_delta is not None + assert content_stop is not None + assert message_stop is not None + assert message_stop["messageStop"]["stopReason"] == "end_turn" + + sagemaker_client.invoke_endpoint_with_response_stream.assert_called_once() + + @pytest.mark.asyncio + async def test_stream_with_tool_calls(self, sagemaker_client, model, messages): + """Test streaming response with tool calls.""" + # Mock the response from SageMaker with tool calls + mock_response = { + "Body": [ + { + "PayloadPart": { + "Bytes": json.dumps( + { + "choices": [ + { + "delta": { + "content": None, + "tool_calls": [ + { + "index": 0, + "id": "tool123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "Paris"}', + }, + } + ], + }, + "finish_reason": "tool_calls", + } + ] + } + ).encode("utf-8") + } + } + ] + } + sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response + + response = [chunk async for chunk in model.stream(messages)] + + # Verify the response contains tool call events + assert len(response) >= 4 + assert response[0] == {"messageStart": {"role": "assistant"}} + + message_stop = next((e for e in response if "messageStop" in e), None) + assert message_stop is not None + assert message_stop["messageStop"]["stopReason"] == "tool_use" + + # Find tool call events + tool_start = next( + ( + e + for e in response + if "contentBlockStart" in e and e.get("contentBlockStart", {}).get("start", {}).get("toolUse") + ), + None, + ) + tool_delta = next( + ( + e + for e in response + if "contentBlockDelta" in e and e.get("contentBlockDelta", {}).get("delta", {}).get("toolUse") + ), + None, + ) + tool_stop = next((e for e in response if "contentBlockStop" in e), None) + + assert tool_start is not None + assert tool_delta is not None + assert tool_stop is not None + + # Verify tool call data + tool_use_data = tool_start["contentBlockStart"]["start"]["toolUse"] + assert tool_use_data["toolUseId"] == "tool123" + assert tool_use_data["name"] == "get_weather" + + @pytest.mark.asyncio + async def test_stream_with_partial_json(self, sagemaker_client, model, messages): + """Test streaming response with partial JSON chunks.""" + # Mock the response from SageMaker with split JSON + mock_response = { + "Body": [ + {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + ] + } + sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response + + response = [chunk async for chunk in model.stream(messages)] + + assert len(response) == 5 + assert response[0] == {"messageStart": {"role": "assistant"}} + + # Find content events + content_start = next((e for e in response if "contentBlockStart" in e), None) + content_delta = next((e for e in response if "contentBlockDelta" in e), None) + content_stop = next((e for e in response if "contentBlockStop" in e), None) + message_stop = next((e for e in response if "messageStop" in e), None) + + assert content_start is not None + assert content_delta is not None + assert content_stop is not None + assert message_stop is not None + assert message_stop["messageStop"]["stopReason"] == "end_turn" + + # Verify content + text_delta = content_delta["contentBlockDelta"]["delta"]["text"] + assert text_delta == "Paris is the capital 
of France." + + @pytest.mark.asyncio + async def test_stream_non_streaming(self, sagemaker_client, model, messages): + """Test non-streaming response.""" + # Configure model for non-streaming + model.payload_config["stream"] = False + + # Mock the response from SageMaker + mock_response = {"Body": unittest.mock.MagicMock()} + mock_response["Body"].read.return_value = json.dumps( + { + "choices": [ + { + "message": {"content": "Paris is the capital of France.", "tool_calls": None}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "prompt_tokens_details": 0}, + } + ).encode("utf-8") + + sagemaker_client.invoke_endpoint.return_value = mock_response + + response = [chunk async for chunk in model.stream(messages)] + + assert len(response) >= 6 + assert response[0] == {"messageStart": {"role": "assistant"}} + + # Find content events + content_start = next((e for e in response if "contentBlockStart" in e), None) + content_delta = next((e for e in response if "contentBlockDelta" in e), None) + content_stop = next((e for e in response if "contentBlockStop" in e), None) + message_stop = next((e for e in response if "messageStop" in e), None) + + assert content_start is not None + assert content_delta is not None + assert content_stop is not None + assert message_stop is not None + + # Verify content + text_delta = content_delta["contentBlockDelta"]["delta"]["text"] + assert text_delta == "Paris is the capital of France." + + sagemaker_client.invoke_endpoint.assert_called_once() + + @pytest.mark.asyncio + async def test_stream_non_streaming_with_tool_calls(self, sagemaker_client, model, messages): + """Test non-streaming response with tool calls.""" + # Configure model for non-streaming + model.payload_config["stream"] = False + + # Mock the response from SageMaker with tool calls + mock_response = {"Body": unittest.mock.MagicMock()} + mock_response["Body"].read.return_value = json.dumps( + { + "choices": [ + { + "message": { + "content": None, + "tool_calls": [ + { + "id": "tool123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'}, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "prompt_tokens_details": 0}, + } + ).encode("utf-8") + + sagemaker_client.invoke_endpoint.return_value = mock_response + + response = [chunk async for chunk in model.stream(messages)] + + # Verify basic structure + assert len(response) >= 6 + assert response[0] == {"messageStart": {"role": "assistant"}} + + # Find tool call events + tool_start = next( + ( + e + for e in response + if "contentBlockStart" in e and e.get("contentBlockStart", {}).get("start", {}).get("toolUse") + ), + None, + ) + tool_delta = next( + ( + e + for e in response + if "contentBlockDelta" in e and e.get("contentBlockDelta", {}).get("delta", {}).get("toolUse") + ), + None, + ) + tool_stop = next((e for e in response if "contentBlockStop" in e), None) + message_stop = next((e for e in response if "messageStop" in e), None) + + assert tool_start is not None + assert tool_delta is not None + assert tool_stop is not None + assert message_stop is not None + + # Verify tool call data + tool_use_data = tool_start["contentBlockStart"]["start"]["toolUse"] + assert tool_use_data["toolUseId"] == "tool123" + assert tool_use_data["name"] == "get_weather" + + # Verify metadata + metadata = next((e for e in response if "metadata" in e), None) + assert metadata is not 
None + usage_data = metadata["metadata"]["usage"] + assert usage_data["totalTokens"] == 30 + + +class TestDataClasses: + """Test suite for data classes.""" + + def test_usage_metadata(self): + """Test UsageMetadata dataclass.""" + usage = UsageMetadata(total_tokens=100, completion_tokens=30, prompt_tokens=70, prompt_tokens_details=5) + + assert usage.total_tokens == 100 + assert usage.completion_tokens == 30 + assert usage.prompt_tokens == 70 + assert usage.prompt_tokens_details == 5 + + def test_function_call(self): + """Test FunctionCall dataclass.""" + func = FunctionCall(name="get_weather", arguments='{"location": "Paris"}') + + assert func.name == "get_weather" + assert func.arguments == '{"location": "Paris"}' + + # Test initialization with kwargs + func2 = FunctionCall(**{"name": "get_time", "arguments": '{"timezone": "UTC"}'}) + + assert func2.name == "get_time" + assert func2.arguments == '{"timezone": "UTC"}' + + def test_tool_call(self): + """Test ToolCall dataclass.""" + # Create a tool call using kwargs directly + tool = ToolCall( + id="tool123", type="function", function={"name": "get_weather", "arguments": '{"location": "Paris"}'} + ) + + assert tool.id == "tool123" + assert tool.type == "function" + assert tool.function.name == "get_weather" + assert tool.function.arguments == '{"location": "Paris"}' + + # Test initialization with kwargs + tool2 = ToolCall( + **{ + "id": "tool456", + "type": "function", + "function": {"name": "get_time", "arguments": '{"timezone": "UTC"}'}, + } + ) + + assert tool2.id == "tool456" + assert tool2.type == "function" + assert tool2.function.name == "get_time" + assert tool2.function.arguments == '{"timezone": "UTC"}' diff --git a/tests_integ/models/test_model_sagemaker.py b/tests_integ/models/test_model_sagemaker.py new file mode 100644 index 000000000..62362e299 --- /dev/null +++ b/tests_integ/models/test_model_sagemaker.py @@ -0,0 +1,76 @@ +import os + +import pytest + +import strands +from strands import Agent +from strands.models.sagemaker import SageMakerAIModel + + +@pytest.fixture +def model(): + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name=os.getenv("SAGEMAKER_ENDPOINT_NAME", ""), region_name="us-east-1" + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024, temperature=0.7, stream=False) + return SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time(location: str) -> str: + """Get the current time for a location.""" + return f"The time in {location} is 12:00 PM" + + @strands.tool + def tool_weather(location: str) -> str: + """Get the current weather for a location.""" + return f"The weather in {location} is sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant that provides concise answers." 
+ + +@pytest.fixture +def agent(model, tools, system_prompt): + return Agent(model=model, tools=tools, system_prompt=system_prompt) + + +@pytest.mark.skipif( + "SAGEMAKER_ENDPOINT_NAME" not in os.environ, + reason="SAGEMAKER_ENDPOINT_NAME environment variable missing", +) +def test_agent_with_tools(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert "12:00" in text and "sunny" in text + + +@pytest.mark.skipif( + "SAGEMAKER_ENDPOINT_NAME" not in os.environ, + reason="SAGEMAKER_ENDPOINT_NAME environment variable missing", +) +def test_agent_without_tools(model, system_prompt): + agent = Agent(model=model, system_prompt=system_prompt) + result = agent("Hello, how are you?") + + assert result.message["content"][0]["text"] + assert len(result.message["content"][0]["text"]) > 0 + + +@pytest.mark.skipif( + "SAGEMAKER_ENDPOINT_NAME" not in os.environ, + reason="SAGEMAKER_ENDPOINT_NAME environment variable missing", +) +@pytest.mark.parametrize("location", ["Tokyo", "London", "Sydney"]) +def test_agent_different_locations(agent, location): + result = agent(f"What is the weather in {location}?") + text = result.message["content"][0]["text"].lower() + + assert location.lower() in text and "sunny" in text From 3f4c3a35ce14800e4852998e0c2b68f90295ffb7 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Mon, 28 Jul 2025 10:23:43 -0400 Subject: [PATCH 012/104] fix: Remove leftover print statement from sagemaker model provider (#553) --- src/strands/models/sagemaker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index bb2db45a2..9cfe27d9e 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -274,8 +274,6 @@ def format_request( if self.endpoint_config.get("additional_args"): request.update(self.endpoint_config["additional_args"].__dict__) - print(json.dumps(request["Body"], indent=2)) - return request @override From bdc893bbae711c1af301e6f18901cb30814789a0 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 29 Jul 2025 14:41:57 -0400 Subject: [PATCH 013/104] [Feat] Update structured output error message (#563) * Update bedrock.py * Update anthropic.py --- src/strands/models/anthropic.py | 2 +- src/strands/models/bedrock.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index eb72becfd..0d734b762 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -414,7 +414,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") + raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 679f1ea3d..cf1e4d3a9 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -584,7 +584,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") content = messages["content"] output_response: dict[str, Any] | None = None From 
4e0e0a648c7e441ce15eacca213b7b65e982fd3b Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 29 Jul 2025 18:03:19 -0400 Subject: [PATCH 014/104] feat(mcp): retain structured content in the AgentTool response (#528) --- pyproject.toml | 2 +- src/strands/models/bedrock.py | 53 +++++++++- src/strands/tools/mcp/mcp_client.py | 49 +++++++--- src/strands/tools/mcp/mcp_types.py | 20 ++++ tests/strands/models/test_bedrock.py | 96 ++++++++++++------- tests/strands/tools/mcp/test_mcp_client.py | 67 +++++++++++++ tests_integ/echo_server.py | 16 +++- tests_integ/test_mcp_client.py | 77 +++++++++++++++ ...cp_client_structured_content_with_hooks.py | 65 +++++++++++++ 9 files changed, 389 insertions(+), 56 deletions(-) create mode 100644 tests_integ/test_mcp_client_structured_content_with_hooks.py diff --git a/pyproject.toml b/pyproject.toml index 745c80e0c..095a38cb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", "docstring_parser>=0.15,<1.0", - "mcp>=1.8.0,<2.0.0", + "mcp>=1.11.0,<2.0.0", "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index cf1e4d3a9..9b36b4244 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -17,10 +17,10 @@ from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec -from ..types.content import Messages +from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolResult, ToolSpec from .model import Model logger = logging.getLogger(__name__) @@ -181,7 +181,7 @@ def format_request( """ return { "modelId": self.config["model_id"], - "messages": messages, + "messages": self._format_bedrock_messages(messages), "system": [ *([{"text": system_prompt}] if system_prompt else []), *([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []), @@ -246,6 +246,53 @@ def format_request( ), } + def _format_bedrock_messages(self, messages: Messages) -> Messages: + """Format messages for Bedrock API compatibility. + + This function ensures messages conform to Bedrock's expected format by: + - Cleaning tool result content blocks by removing additional fields that may be + useful for retaining information in hooks but would cause Bedrock validation + exceptions when presented with unexpected fields + - Ensuring all message content blocks are properly formatted for the Bedrock API + + Args: + messages: List of messages to format + + Returns: + Messages formatted for Bedrock API compatibility + + Note: + Bedrock will throw validation exceptions when presented with additional + unexpected fields in tool result blocks. 
+ https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html + """ + cleaned_messages = [] + + for message in messages: + cleaned_content: list[ContentBlock] = [] + + for content_block in message["content"]: + if "toolResult" in content_block: + # Create a new content block with only the cleaned toolResult + tool_result: ToolResult = content_block["toolResult"] + + # Keep only the required fields for Bedrock + cleaned_tool_result = ToolResult( + content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"] + ) + + cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result} + cleaned_content.append(cleaned_block) + else: + # Keep other content blocks as-is + cleaned_content.append(content_block) + + # Create new message with cleaned content + cleaned_message: Message = Message(content=cleaned_content, role=message["role"]) + cleaned_messages.append(cleaned_message) + + return cleaned_messages + def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool: """Check if guardrail data contains any blocked policies. diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 4cf4e1f85..784636fd0 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -26,9 +26,9 @@ from ...types import PaginatedList from ...types.exceptions import MCPClientInitializationError from ...types.media import ImageFormat -from ...types.tools import ToolResult, ToolResultContent, ToolResultStatus +from ...types.tools import ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool -from .mcp_types import MCPTransport +from .mcp_types import MCPToolResult, MCPTransport logger = logging.getLogger(__name__) @@ -57,7 +57,8 @@ class MCPClient: It handles the creation, initialization, and cleanup of MCP connections. The connection runs in a background thread to avoid blocking the main application thread - while maintaining communication with the MCP service. + while maintaining communication with the MCP service. When structured content is available + from MCP tools, it will be returned as the last item in the content array of the ToolResult. """ def __init__(self, transport_callable: Callable[[], MCPTransport]): @@ -170,11 +171,13 @@ def call_tool_sync( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - ) -> ToolResult: + ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. + and converts the result to the ToolResult format. If the MCP tool returns + structured content, it will be included as the last item in the content array + of the returned ToolResult. Args: tool_use_id: Unique identifier for this tool use @@ -183,7 +186,7 @@ def call_tool_sync( read_timeout_seconds: Optional timeout for the tool call Returns: - ToolResult: The result of the tool call + MCPToolResult: The result of the tool call """ self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id) if not self._is_session_active(): @@ -205,11 +208,11 @@ async def call_tool_async( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - ) -> ToolResult: + ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. 
This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. + and converts the result to the MCPToolResult format. Args: tool_use_id: Unique identifier for this tool use @@ -218,7 +221,7 @@ async def call_tool_async( read_timeout_seconds: Optional timeout for the tool call Returns: - ToolResult: The result of the tool call + MCPToolResult: The result of the tool call """ self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id) if not self._is_session_active(): @@ -235,15 +238,27 @@ async def _call_tool_async() -> MCPCallToolResult: logger.exception("tool execution failed") return self._handle_tool_execution_error(tool_use_id, e) - def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> ToolResult: + def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult: """Create error ToolResult with consistent logging.""" - return ToolResult( + return MCPToolResult( status="error", toolUseId=tool_use_id, content=[{"text": f"Tool execution failed: {str(exception)}"}], ) - def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> ToolResult: + def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult: + """Maps MCP tool result to the agent's MCPToolResult format. + + This method processes the content from the MCP tool call result and converts it to the format + expected by the framework. + + Args: + tool_use_id: Unique identifier for this tool use + call_tool_result: The result from the MCP tool call + + Returns: + MCPToolResult: The converted tool result + """ self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content)) mapped_content = [ @@ -254,7 +269,15 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes status: ToolResultStatus = "error" if call_tool_result.isError else "success" self._log_debug_with_thread("tool execution completed with status: %s", status) - return ToolResult(status=status, toolUseId=tool_use_id, content=mapped_content) + result = MCPToolResult( + status=status, + toolUseId=tool_use_id, + content=mapped_content, + ) + if call_tool_result.structuredContent: + result["structuredContent"] = call_tool_result.structuredContent + + return result async def _async_background_thread(self) -> None: """Asynchronous method that runs in the background thread to manage the MCP connection. diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 30defc585..5fafed5dc 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -1,11 +1,15 @@ """Type definitions for MCP integration.""" from contextlib import AbstractAsyncContextManager +from typing import Any, Dict from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.streamable_http import GetSessionIdCallback from mcp.shared.memory import MessageStream from mcp.shared.message import SessionMessage +from typing_extensions import NotRequired + +from strands.types.tools import ToolResult """ MCPTransport defines the interface for MCP transport implementations. 
This abstracts @@ -41,3 +45,19 @@ async def my_transport_implementation(): MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback ] MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback] + + +class MCPToolResult(ToolResult): + """Result of an MCP tool execution. + + Extends the base ToolResult with MCP-specific structured content support. + The structuredContent field contains optional JSON data returned by MCP tools + that provides structured results beyond the standard text/image/document content. + + Attributes: + structuredContent: Optional JSON object containing structured data returned + by the MCP tool. This allows MCP tools to return complex data structures + that can be processed programmatically by agents or other tools. + """ + + structuredContent: NotRequired[Dict[str, Any]] diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 47e028cb9..0a2846adf 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -13,6 +13,7 @@ from strands.models import BedrockModel from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION from strands.types.exceptions import ModelThrottledException +from strands.types.tools import ToolSpec @pytest.fixture @@ -51,7 +52,7 @@ def model(bedrock_client, model_id): @pytest.fixture def messages(): - return [{"role": "user", "content": {"text": "test"}}] + return [{"role": "user", "content": [{"text": "test"}]}] @pytest.fixture @@ -90,8 +91,12 @@ def inference_config(): @pytest.fixture -def tool_spec(): - return {"t1": 1} +def tool_spec() -> ToolSpec: + return { + "description": "description", + "name": "name", + "inputSchema": {"key": "val"}, + } @pytest.fixture @@ -750,7 +755,7 @@ async def test_stream_output_no_guardrail_redact( @pytest.mark.asyncio -async def test_stream_with_streaming_false(bedrock_client, alist): +async def test_stream_with_streaming_false(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -759,8 +764,7 @@ async def test_stream_with_streaming_false(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -776,7 +780,7 @@ async def test_stream_with_streaming_false(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist): +async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -790,8 +794,7 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -808,7 +811,7 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist): +async def 
test_stream_with_streaming_false_and_reasoning(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -828,8 +831,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -847,7 +849,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_and_reasoning_no_signature(bedrock_client, alist): +async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -867,8 +869,7 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -884,7 +885,7 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist): +async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -895,8 +896,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -919,7 +919,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client @pytest.mark.asyncio -async def test_stream_input_guardrails(bedrock_client, alist): +async def test_stream_input_guardrails(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -937,8 +937,7 @@ async def test_stream_input_guardrails(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -970,7 +969,7 @@ async def test_stream_input_guardrails(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_output_guardrails(bedrock_client, alist): +async def test_stream_output_guardrails(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -989,8 +988,7 @@ async def test_stream_output_guardrails(bedrock_client, alist): } model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ 
@@ -1024,7 +1022,7 @@ async def test_stream_output_guardrails(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_output_guardrails_redacts_output(bedrock_client, alist): +async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -1043,8 +1041,7 @@ async def test_stream_output_guardrails_redacts_output(bedrock_client, alist): } model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -1101,7 +1098,7 @@ async def test_structured_output(bedrock_client, model, test_output_model_cls, a @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") @pytest.mark.asyncio -async def test_add_note_on_client_error(bedrock_client, model, alist): +async def test_add_note_on_client_error(bedrock_client, model, alist, messages): """Test that add_note is called on ClientError with region and model ID information.""" # Mock the client error response error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} @@ -1109,13 +1106,13 @@ async def test_add_note_on_client_error(bedrock_client, model, alist): # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) assert err.value.__notes__ == ["ā”” Bedrock region: us-west-2", "ā”” Model id: m1"] @pytest.mark.asyncio -async def test_no_add_note_when_not_available(bedrock_client, model, alist): +async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages): """Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception).""" # Mock the client error response error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} @@ -1123,12 +1120,12 @@ async def test_no_add_note_when_not_available(bedrock_client, model, alist): # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError): - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") @pytest.mark.asyncio -async def test_add_note_on_access_denied_exception(bedrock_client, model, alist): +async def test_add_note_on_access_denied_exception(bedrock_client, model, alist, messages): """Test that add_note adds documentation link for AccessDeniedException.""" # Mock the client error response for access denied error_response = { @@ -1142,7 +1139,7 @@ async def test_add_note_on_access_denied_exception(bedrock_client, model, alist) # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) assert err.value.__notes__ == [ "ā”” Bedrock region: us-west-2", @@ -1154,7 +1151,7 @@ async def test_add_note_on_access_denied_exception(bedrock_client, model, alist) @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") 
@pytest.mark.asyncio -async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist): +async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist, messages): """Test that add_note adds documentation link for ValidationException about on-demand throughput.""" # Mock the client error response for validation exception error_response = { @@ -1170,7 +1167,7 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) assert err.value.__notes__ == [ "ā”” Bedrock region: us-west-2", @@ -1202,3 +1199,32 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist): assert "invoking model" in log_text assert "got response from model" in log_text assert "finished streaming response from model" in log_text + + +def test_format_request_cleans_tool_result_content_blocks(model, model_id): + """Test that format_request cleans toolResult blocks by removing extra fields.""" + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "content": [{"text": "Tool output"}], + "toolUseId": "tool123", + "status": "success", + "extraField": "should be removed", + "mcpMetadata": {"server": "test"}, + } + }, + ], + } + ] + + formatted_request = model.format_request(messages) + + # Verify toolResult only contains allowed fields in the formatted request + tool_result = formatted_request["messages"][0]["content"][0]["toolResult"] + expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"} + assert tool_result == expected + assert "extraField" not in tool_result + assert "mcpMetadata" not in tool_result diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 6a2fdd00c..3d3792c71 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -8,6 +8,7 @@ from mcp.types import Tool as MCPTool from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_types import MCPToolResult from strands.types.exceptions import MCPClientInitializationError @@ -129,6 +130,8 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ assert result["toolUseId"] == "test-123" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "Test message" + # No structured content should be present when not provided by MCP + assert result.get("structuredContent") is None def test_call_tool_sync_session_not_active(): @@ -139,6 +142,31 @@ def test_call_tool_sync_session_not_active(): client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) +def test_call_tool_sync_with_structured_content(mock_transport, mock_session): + """Test that call_tool_sync correctly handles structured content.""" + mock_content = MCPTextContent(type="text", text="Test message") + structured_content = {"result": 42, "status": "completed"} + mock_session.call_tool.return_value = MCPCallToolResult( + isError=False, content=[mock_content], structuredContent=structured_content + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + + assert 
result["status"] == "success" + assert result["toolUseId"] == "test-123" + # Content should only contain the text content, not the structured content + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == "Test message" + # Structured content should be in its own field + assert "structuredContent" in result + assert result["structuredContent"] == structured_content + assert result["structuredContent"]["result"] == 42 + assert result["structuredContent"]["status"] == "completed" + + def test_call_tool_sync_exception(mock_transport, mock_session): """Test that call_tool_sync correctly handles exceptions.""" mock_session.call_tool.side_effect = Exception("Test exception") @@ -312,6 +340,45 @@ def test_enter_with_initialization_exception(mock_transport): client.start() +def test_mcp_tool_result_type(): + """Test that MCPToolResult extends ToolResult correctly.""" + # Test basic ToolResult functionality + result = MCPToolResult(status="success", toolUseId="test-123", content=[{"text": "Test message"}]) + + assert result["status"] == "success" + assert result["toolUseId"] == "test-123" + assert result["content"][0]["text"] == "Test message" + + # Test that structuredContent is optional + assert "structuredContent" not in result or result.get("structuredContent") is None + + # Test with structuredContent + result_with_structured = MCPToolResult( + status="success", toolUseId="test-456", content=[{"text": "Test message"}], structuredContent={"key": "value"} + ) + + assert result_with_structured["structuredContent"] == {"key": "value"} + + +def test_call_tool_sync_without_structured_content(mock_transport, mock_session): + """Test that call_tool_sync works correctly when no structured content is provided.""" + mock_content = MCPTextContent(type="text", text="Test message") + mock_session.call_tool.return_value = MCPCallToolResult( + isError=False, + content=[mock_content], # No structuredContent + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + assert result["status"] == "success" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == "Test message" + # structuredContent should be None when not provided by MCP + assert result.get("structuredContent") is None + + def test_exception_when_future_not_running(): """Test exception handling when the future is not running.""" # Create a client.with a mock transport diff --git a/tests_integ/echo_server.py b/tests_integ/echo_server.py index d309607a8..52223792c 100644 --- a/tests_integ/echo_server.py +++ b/tests_integ/echo_server.py @@ -2,7 +2,7 @@ Echo Server for MCP Integration Testing This module implements a simple echo server using the Model Context Protocol (MCP). -It provides a basic tool that echoes back any input string, which is useful for +It provides basic tools that echo back input strings and structured content, which is useful for testing the MCP communication flow and validating that messages are properly transmitted between the client and server. @@ -15,6 +15,8 @@ $ python echo_server.py """ +from typing import Any, Dict + from mcp.server import FastMCP @@ -22,16 +24,22 @@ def start_echo_server(): """ Initialize and start the MCP echo server. - Creates a FastMCP server instance with a single 'echo' tool that returns - any input string back to the caller. 
The server uses stdio transport + Creates a FastMCP server instance with tools that return + input strings and structured content back to the caller. The server uses stdio transport for communication. + """ mcp = FastMCP("Echo Server") - @mcp.tool(description="Echos response back to the user") + @mcp.tool(description="Echos response back to the user", structured_output=False) def echo(to_echo: str) -> str: return to_echo + # FastMCP automatically constructs structured output schema from method signature + @mcp.tool(description="Echos response back with structured content", structured_output=True) + def echo_with_structured_content(to_echo: str) -> Dict[str, Any]: + return {"echoed": to_echo} + mcp.run(transport="stdio") diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py index 9163f625d..ebd4f5896 100644 --- a/tests_integ/test_mcp_client.py +++ b/tests_integ/test_mcp_client.py @@ -1,4 +1,5 @@ import base64 +import json import os import threading import time @@ -87,6 +88,24 @@ def test_mcp_client(): ] ) + tool_use_id = "test-structured-content-123" + result = stdio_mcp_client.call_tool_sync( + tool_use_id=tool_use_id, + name="echo_with_structured_content", + arguments={"to_echo": "STRUCTURED_DATA_TEST"}, + ) + + # With the new MCPToolResult, structured content is in its own field + assert "structuredContent" in result + assert result["structuredContent"]["result"] == {"echoed": "STRUCTURED_DATA_TEST"} + + # Verify the result is an MCPToolResult (at runtime it's just a dict, but type-wise it should be MCPToolResult) + assert result["status"] == "success" + assert result["toolUseId"] == tool_use_id + + assert len(result["content"]) == 1 + assert json.loads(result["content"][0]["text"]) == {"echoed": "STRUCTURED_DATA_TEST"} + def test_can_reuse_mcp_client(): stdio_mcp_client = MCPClient( @@ -103,6 +122,64 @@ def test_can_reuse_mcp_client(): assert any([block["name"] == "echo" for block in tool_use_content_blocks]) +@pytest.mark.asyncio +async def test_mcp_client_async_structured_content(): + """Test that async MCP client calls properly handle structured content. + + This test demonstrates how tools configure structured output: FastMCP automatically + constructs structured output schema from method signature when structured_output=True + is set in the @mcp.tool decorator. The return type annotation defines the structure + that appears in structuredContent field. 
+ """ + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + with stdio_mcp_client: + tool_use_id = "test-async-structured-content-456" + result = await stdio_mcp_client.call_tool_async( + tool_use_id=tool_use_id, + name="echo_with_structured_content", + arguments={"to_echo": "ASYNC_STRUCTURED_TEST"}, + ) + + # Verify structured content is in its own field + assert "structuredContent" in result + # "result" nesting is not part of the MCP Structured Content specification, + # but rather a FastMCP implementation detail + assert result["structuredContent"]["result"] == {"echoed": "ASYNC_STRUCTURED_TEST"} + + # Verify basic MCPToolResult structure + assert result["status"] in ["success", "error"] + assert result["toolUseId"] == tool_use_id + + assert len(result["content"]) == 1 + assert json.loads(result["content"][0]["text"]) == {"echoed": "ASYNC_STRUCTURED_TEST"} + + +def test_mcp_client_without_structured_content(): + """Test that MCP client works correctly when tools don't return structured content.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + with stdio_mcp_client: + tool_use_id = "test-no-structured-content-789" + result = stdio_mcp_client.call_tool_sync( + tool_use_id=tool_use_id, + name="echo", # This tool doesn't return structured content + arguments={"to_echo": "SIMPLE_ECHO_TEST"}, + ) + + # Verify no structured content when tool doesn't provide it + assert result.get("structuredContent") is None + + # Verify basic result structure + assert result["status"] == "success" + assert result["toolUseId"] == tool_use_id + assert result["content"] == [{"text": "SIMPLE_ECHO_TEST"}] + + @pytest.mark.skipif( condition=os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", diff --git a/tests_integ/test_mcp_client_structured_content_with_hooks.py b/tests_integ/test_mcp_client_structured_content_with_hooks.py new file mode 100644 index 000000000..ca2468c48 --- /dev/null +++ b/tests_integ/test_mcp_client_structured_content_with_hooks.py @@ -0,0 +1,65 @@ +"""Integration test demonstrating hooks system with MCP client structured content tool. + +This test shows how to use the hooks system to capture and inspect tool invocation +results, specifically testing the echo_with_structured_content tool from echo_server. 
+""" + +import json + +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.experimental.hooks import AfterToolInvocationEvent +from strands.hooks import HookProvider, HookRegistry +from strands.tools.mcp.mcp_client import MCPClient + + +class StructuredContentHookProvider(HookProvider): + """Hook provider that captures structured content tool results.""" + + def __init__(self): + self.captured_result = None + + def register_hooks(self, registry: HookRegistry) -> None: + """Register callback for after tool invocation events.""" + registry.add_callback(AfterToolInvocationEvent, self.on_after_tool_invocation) + + def on_after_tool_invocation(self, event: AfterToolInvocationEvent) -> None: + """Capture structured content tool results.""" + if event.tool_use["name"] == "echo_with_structured_content": + self.captured_result = event.result + + +def test_mcp_client_hooks_structured_content(): + """Test using hooks to inspect echo_with_structured_content tool result.""" + # Create hook provider to capture tool result + hook_provider = StructuredContentHookProvider() + + # Set up MCP client for echo server + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + with stdio_mcp_client: + # Create agent with MCP tools and hook provider + agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[hook_provider]) + + # Test structured content functionality + test_data = "HOOKS_TEST_DATA" + agent(f"Use the echo_with_structured_content tool to echo: {test_data}") + + # Verify hook captured the tool result + assert hook_provider.captured_result is not None + result = hook_provider.captured_result + + # Verify basic result structure + assert result["status"] == "success" + assert len(result["content"]) == 1 + + # Verify structured content is present and correct + assert "structuredContent" in result + assert result["structuredContent"]["result"] == {"echoed": test_data} + + # Verify text content matches structured content + text_content = json.loads(result["content"][0]["text"]) + assert text_content == {"echoed": test_data} From b13c5c5492e7745acb86d23eb215acdce0120361 Mon Sep 17 00:00:00 2001 From: Ketan Suhaas Saichandran <55935983+Ketansuhaas@users.noreply.github.com> Date: Wed, 30 Jul 2025 08:59:29 -0400 Subject: [PATCH 015/104] feat(mcp): Add list_prompts, get_prompt methods (#160) Co-authored-by: ketan-clairyon Co-authored-by: Dean Schmigelski --- src/strands/tools/mcp/mcp_client.py | 49 +++++++++++++ tests/strands/tools/mcp/test_mcp_client.py | 62 ++++++++++++++++ tests_integ/test_mcp_client.py | 83 +++++++++++++++++++--- 3 files changed, 184 insertions(+), 10 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 784636fd0..8c21baa4a 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -20,6 +20,7 @@ from mcp import ClientSession, ListToolsResult from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import GetPromptResult, ListPromptsResult from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent @@ -165,6 +166,54 @@ async def _list_tools_async() -> ListToolsResult: self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) + def list_prompts_sync(self, pagination_token: Optional[str] = None) -> 
ListPromptsResult: + """Synchronously retrieves the list of available prompts from the MCP server. + + This method calls the asynchronous list_prompts method on the MCP session + and returns the raw ListPromptsResult with pagination support. + + Args: + pagination_token: Optional token for pagination + + Returns: + ListPromptsResult: The raw MCP response containing prompts and pagination info + """ + self._log_debug_with_thread("listing MCP prompts synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _list_prompts_async() -> ListPromptsResult: + return await self._background_thread_session.list_prompts(cursor=pagination_token) + + list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result() + self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts)) + for prompt in list_prompts_result.prompts: + self._log_debug_with_thread(prompt.name) + + return list_prompts_result + + def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResult: + """Synchronously retrieves a prompt from the MCP server. + + Args: + prompt_id: The ID of the prompt to retrieve + args: Optional arguments to pass to the prompt + + Returns: + GetPromptResult: The prompt response from the MCP server + """ + self._log_debug_with_thread("getting MCP prompt synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _get_prompt_async() -> GetPromptResult: + return await self._background_thread_session.get_prompt(prompt_id, arguments=args) + + get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result() + self._log_debug_with_thread("received prompt from MCP server") + + return get_prompt_result + def call_tool_sync( self, tool_use_id: str, diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 3d3792c71..bd88382cd 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -4,6 +4,7 @@ import pytest from mcp import ListToolsResult from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import GetPromptResult, ListPromptsResult, Prompt, PromptMessage from mcp.types import TextContent as MCPTextContent from mcp.types import Tool as MCPTool @@ -404,3 +405,64 @@ def test_exception_when_future_not_running(): # Verify that set_exception was not called since the future was not running mock_future.set_exception.assert_not_called() + + +# Prompt Tests - Sync Methods + + +def test_list_prompts_sync(mock_transport, mock_session): + """Test that list_prompts_sync correctly retrieves prompts.""" + mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1") + mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_prompts_sync() + + mock_session.list_prompts.assert_called_once_with(cursor=None) + assert len(result.prompts) == 1 + assert result.prompts[0].name == "test_prompt" + assert result.nextCursor is None + + +def test_list_prompts_sync_with_pagination_token(mock_transport, mock_session): + """Test that list_prompts_sync correctly passes pagination token and returns next cursor.""" + mock_prompt = Prompt(name="test_prompt", description="A test prompt", 
id="prompt_1") + mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt], nextCursor="next_page_token") + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_prompts_sync(pagination_token="current_page_token") + + mock_session.list_prompts.assert_called_once_with(cursor="current_page_token") + assert len(result.prompts) == 1 + assert result.prompts[0].name == "test_prompt" + assert result.nextCursor == "next_page_token" + + +def test_list_prompts_sync_session_not_active(): + """Test that list_prompts_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.list_prompts_sync() + + +def test_get_prompt_sync(mock_transport, mock_session): + """Test that get_prompt_sync correctly retrieves a prompt.""" + mock_message = PromptMessage(role="user", content=MCPTextContent(type="text", text="This is a test prompt")) + mock_session.get_prompt.return_value = GetPromptResult(messages=[mock_message]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.get_prompt_sync("test_prompt_id", {"key": "value"}) + + mock_session.get_prompt.assert_called_once_with("test_prompt_id", arguments={"key": "value"}) + assert len(result.messages) == 1 + assert result.messages[0].role == "user" + assert result.messages[0].content.text == "This is a test prompt" + + +def test_get_prompt_sync_session_not_active(): + """Test that get_prompt_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.get_prompt_sync("test_prompt_id", {}) diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py index ebd4f5896..3de249435 100644 --- a/tests_integ/test_mcp_client.py +++ b/tests_integ/test_mcp_client.py @@ -18,18 +18,17 @@ from strands.types.tools import ToolUse -def start_calculator_server(transport: Literal["sse", "streamable-http"], port=int): +def start_comprehensive_mcp_server(transport: Literal["sse", "streamable-http"], port=int): """ - Initialize and start an MCP calculator server for integration testing. + Initialize and start a comprehensive MCP server for integration testing. - This function creates a FastMCP server instance that provides a simple - calculator tool for performing addition operations. The server uses - Server-Sent Events (SSE) transport for communication, making it accessible - over HTTP. + This function creates a FastMCP server instance that provides tools, prompts, + and resources all in one server for comprehensive testing. The server uses + Server-Sent Events (SSE) or streamable HTTP transport for communication. """ from mcp.server import FastMCP - mcp = FastMCP("Calculator Server", port=port) + mcp = FastMCP("Comprehensive MCP Server", port=port) @mcp.tool(description="Calculator tool which performs calculations") def calculator(x: int, y: int) -> int: @@ -44,6 +43,15 @@ def generate_custom_image() -> MCPImageContent: except Exception as e: print("Error while generating custom image: {}".format(e)) + # Prompts + @mcp.prompt(description="A greeting prompt template") + def greeting_prompt(name: str = "World") -> str: + return f"Hello, {name}! How are you today?" 
+ + @mcp.prompt(description="A math problem prompt template") + def math_prompt(operation: str = "addition", difficulty: str = "easy") -> str: + return f"Create a {difficulty} {operation} math problem and solve it step by step." + mcp.run(transport=transport) @@ -58,8 +66,9 @@ def test_mcp_client(): {'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]} """ # noqa: E501 + # Start comprehensive server with tools, prompts, and resources server_thread = threading.Thread( - target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True + target=start_comprehensive_mcp_server, kwargs={"transport": "sse", "port": 8000}, daemon=True ) server_thread.start() time.sleep(2) # wait for server to startup completely @@ -68,8 +77,14 @@ def test_mcp_client(): stdio_mcp_client = MCPClient( lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) ) + with sse_mcp_client, stdio_mcp_client: - agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync()) + # Test Tools functionality + sse_tools = sse_mcp_client.list_tools_sync() + stdio_tools = stdio_mcp_client.list_tools_sync() + all_tools = sse_tools + stdio_tools + + agent = Agent(tools=all_tools) agent("add 1 and 2, then echo the result back to me") tool_use_content_blocks = _messages_to_content_blocks(agent.messages) @@ -88,6 +103,43 @@ def test_mcp_client(): ] ) + # Test Prompts functionality + prompts_result = sse_mcp_client.list_prompts_sync() + assert len(prompts_result.prompts) >= 2 # We expect at least greeting and math prompts + + prompt_names = [prompt.name for prompt in prompts_result.prompts] + assert "greeting_prompt" in prompt_names + assert "math_prompt" in prompt_names + + # Test get_prompt_sync with greeting prompt + greeting_result = sse_mcp_client.get_prompt_sync("greeting_prompt", {"name": "Alice"}) + assert len(greeting_result.messages) > 0 + prompt_text = greeting_result.messages[0].content.text + assert "Hello, Alice!" in prompt_text + assert "How are you today?" 
in prompt_text + + # Test get_prompt_sync with math prompt + math_result = sse_mcp_client.get_prompt_sync( + "math_prompt", {"operation": "multiplication", "difficulty": "medium"} + ) + assert len(math_result.messages) > 0 + math_text = math_result.messages[0].content.text + assert "multiplication" in math_text + assert "medium" in math_text + assert "step by step" in math_text + + # Test pagination support for prompts + prompts_with_token = sse_mcp_client.list_prompts_sync(pagination_token=None) + assert len(prompts_with_token.prompts) >= 0 + + # Test pagination support for tools (existing functionality) + tools_with_token = sse_mcp_client.list_tools_sync(pagination_token=None) + assert len(tools_with_token) >= 0 + + # TODO: Add resources testing when resources are implemented + # resources_result = sse_mcp_client.list_resources_sync() + # assert len(resources_result.resources) >= 0 + tool_use_id = "test-structured-content-123" result = stdio_mcp_client.call_tool_sync( tool_use_id=tool_use_id, @@ -185,8 +237,9 @@ def test_mcp_client_without_structured_content(): reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", ) def test_streamable_http_mcp_client(): + """Test comprehensive MCP client with streamable HTTP transport.""" server_thread = threading.Thread( - target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True + target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True ) server_thread.start() time.sleep(2) # wait for server to startup completely @@ -196,12 +249,22 @@ def transport_callback() -> MCPTransport: streamable_http_client = MCPClient(transport_callback) with streamable_http_client: + # Test tools agent = Agent(tools=streamable_http_client.list_tools_sync()) agent("add 1 and 2 using a calculator") tool_use_content_blocks = _messages_to_content_blocks(agent.messages) assert any([block["name"] == "calculator" for block in tool_use_content_blocks]) + # Test prompts + prompts_result = streamable_http_client.list_prompts_sync() + assert len(prompts_result.prompts) >= 2 + + greeting_result = streamable_http_client.get_prompt_sync("greeting_prompt", {"name": "Charlie"}) + assert len(greeting_result.messages) > 0 + prompt_text = greeting_result.messages[0].content.text + assert "Hello, Charlie!" 
in prompt_text + def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] From 3d526f2e254d38bb83b8ec85af56e79e4e1fe33f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=BF=E3=81=AE=E3=82=8B=E3=82=93?= <74597894+minorun365@users.noreply.github.com> Date: Thu, 31 Jul 2025 23:40:25 +0900 Subject: [PATCH 016/104] fix(deps): pin a2a-sdk>=0.2.16 to resolve #572 (#581) Co-authored-by: Jeremiah --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 095a38cb0..cdf68e01f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ sagemaker = [ ] a2a = [ + "a2a-sdk>=0.2.16,<1.0.0", "a2a-sdk[sql]>=0.2.11,<1.0.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", @@ -321,4 +322,4 @@ style = [ ["instruction", ""], ["text", ""], ["disabled", "fg:#858585 italic"] -] \ No newline at end of file +] From b56a4ff32e93dd74a10c8895cd68528091e88f1b Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 1 Aug 2025 09:42:35 -0400 Subject: [PATCH 017/104] chore: pin a2a to a minor version while it is still in beta (#586) --- pyproject.toml | 6 +++--- src/strands/multiagent/a2a/executor.py | 2 +- tests/strands/multiagent/a2a/test_executor.py | 16 ++++++++-------- tests/strands/multiagent/a2a/test_server.py | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cdf68e01f..586a956af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,8 +96,8 @@ sagemaker = [ ] a2a = [ - "a2a-sdk>=0.2.16,<1.0.0", - "a2a-sdk[sql]>=0.2.11,<1.0.0", + "a2a-sdk>=0.3.0,<0.4.0", + "a2a-sdk[sql]>=0.3.0,<0.4.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", @@ -143,7 +143,7 @@ all = [ "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", # a2a - "a2a-sdk[sql]>=0.2.11,<1.0.0", + "a2a-sdk[sql]>=0.3.0,<0.4.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index d65c64aff..5bf9cbfe9 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -61,7 +61,7 @@ async def execute( task = new_task(context.message) # type: ignore await event_queue.enqueue_event(task) - updater = TaskUpdater(event_queue, task.id, task.contextId) + updater = TaskUpdater(event_queue, task.id, task.context_id) try: await self._execute_streaming(context, updater) diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index a956cb769..77645fc73 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -36,7 +36,7 @@ async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -65,7 +65,7 @@ async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -95,7 +95,7 @@ async def mock_stream(user_input): # Mock the task creation 
mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -125,7 +125,7 @@ async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -156,7 +156,7 @@ async def mock_stream(user_input): mock_request_context.current_task = None with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task: - mock_new_task.return_value = MagicMock(id="new-task-id", contextId="new-context-id") + mock_new_task.return_value = MagicMock(id="new-task-id", context_id="new-context-id") await executor.execute(mock_request_context, mock_event_queue) @@ -180,7 +180,7 @@ async def test_execute_streaming_mode_handles_agent_exception( # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task with pytest.raises(ServerError): @@ -210,7 +210,7 @@ async def test_handle_agent_result_with_none_result(mock_strands_agent, mock_req # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task # Mock TaskUpdater @@ -235,7 +235,7 @@ async def test_handle_agent_result_with_result_but_no_message( # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task # Mock TaskUpdater diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py index fc76b5f1d..a3b47581c 100644 --- a/tests/strands/multiagent/a2a/test_server.py +++ b/tests/strands/multiagent/a2a/test_server.py @@ -87,8 +87,8 @@ def test_public_agent_card(mock_strands_agent): assert card.description == "A test agent for unit testing" assert card.url == "http://0.0.0.0:9000/" assert card.version == "0.0.1" - assert card.defaultInputModes == ["text"] - assert card.defaultOutputModes == ["text"] + assert card.default_input_modes == ["text"] + assert card.default_output_modes == ["text"] assert card.skills == [] assert card.capabilities == a2a_agent.capabilities From 8b1de4d4cc4f8adc5386bb1a134aabf96e698cdd Mon Sep 17 00:00:00 2001 From: Laith Al-Saadoon <9553966+theagenticguy@users.noreply.github.com> Date: Fri, 1 Aug 2025 09:23:25 -0500 Subject: [PATCH 018/104] fix: uses new a2a snake_case for lints to pass (#591) --- src/strands/models/anthropic.py | 2 +- src/strands/models/bedrock.py | 2 +- src/strands/session/file_session_manager.py | 3 ++- src/strands/session/s3_session_manager.py | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 0d734b762..975fca3e9 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -414,7 +414,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise 
ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9b36b4244..4ea1453a4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -631,7 +631,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index b32cb00e6..fec2f0761 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -23,6 +23,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): """File-based session manager for local filesystem storage. Creates the following filesystem structure for the session storage: + ```bash // └── session_/ ā”œā”€ā”€ session.json # Session metadata @@ -32,7 +33,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): └── messages/ ā”œā”€ā”€ message_.json └── message_.json - + ``` """ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any): diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 8f8423828..0cc0a68c1 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -24,6 +24,7 @@ class S3SessionManager(RepositorySessionManager, SessionRepository): """S3-based session manager for cloud storage. 
Creates the following filesystem structure for the session storage: + ```bash // └── session_/ ā”œā”€ā”€ session.json # Session metadata @@ -33,7 +34,7 @@ class S3SessionManager(RepositorySessionManager, SessionRepository): └── messages/ ā”œā”€ā”€ message_.json └── message_.json - + ``` """ def __init__( From c85464c45715a9d2ef3f9377f59f9e970ee81cf9 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 1 Aug 2025 10:37:17 -0400 Subject: [PATCH 019/104] =?UTF-8?q?fix(event=5Floop):=20raise=20dedicated?= =?UTF-8?q?=20exception=20when=20encountering=20max=20toke=E2=80=A6=20(#57?= =?UTF-8?q?6)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(event_loop): raise dedicated exception when encountering max tokens stop reason * fix: update integ tests * fix: rename exception message, add to exception, move earlier in cycle * Update tests_integ/test_max_tokens_reached.py Co-authored-by: Nick Clegg * Update tests_integ/test_max_tokens_reached.py Co-authored-by: Nick Clegg * linting --------- Co-authored-by: Nick Clegg --- src/strands/event_loop/event_loop.py | 26 ++++++++++- src/strands/types/exceptions.py | 21 +++++++++ tests/strands/event_loop/test_event_loop.py | 52 ++++++++++++++++++++- tests_integ/test_max_tokens_reached.py | 20 ++++++++ 4 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 tests_integ/test_max_tokens_reached.py diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index ffcb6a5c9..ae21d4c6d 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -28,7 +28,12 @@ from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message -from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from ..types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + MaxTokensReachedException, + ModelThrottledException, +) from ..types.streaming import Metrics, StopReason from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse from .streaming import stream_messages @@ -187,6 +192,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> raise e try: + if stop_reason == "max_tokens": + """ + Handle max_tokens limit reached by the model. + + When the model reaches its maximum token limit, this represents a potentially unrecoverable + state where the model's response was truncated. By default, Strands fails hard with an + MaxTokensReachedException to maintain consistency with other failure types. + """ + raise MaxTokensReachedException( + message=( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ), + incomplete_message=message, + ) # Add message in trace and mark the end of the stream messages trace stream_trace.add_message(message) stream_trace.end() @@ -231,7 +252,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Don't yield or log the exception - we already did it when we # raised the exception and we don't need that duplication. 
raise - except ContextWindowOverflowException as e: + except (ContextWindowOverflowException, MaxTokensReachedException) as e: + # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) raise e diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 4bd3fd88e..71ea28b9f 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -2,6 +2,8 @@ from typing import Any +from strands.types.content import Message + class EventLoopException(Exception): """Exception raised by the event loop.""" @@ -18,6 +20,25 @@ def __init__(self, original_exception: Exception, request_state: Any = None) -> super().__init__(str(original_exception)) +class MaxTokensReachedException(Exception): + """Exception raised when the model reaches its maximum token generation limit. + + This exception is raised when the model stops generating tokens because it has reached the maximum number of + tokens allowed for output generation. This can occur when the model's max_tokens parameter is set too low for + the complexity of the response, or when the model naturally reaches its configured output limit during generation. + """ + + def __init__(self, message: str, incomplete_message: Message): + """Initialize the exception with an error message and the incomplete message object. + + Args: + message: The error message describing the token limit issue + incomplete_message: The valid Message object with incomplete content due to token limits + """ + self.incomplete_message = incomplete_message + super().__init__(message) + + class ContextWindowOverflowException(Exception): """Exception raised when the context window is exceeded. 
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 1ac2f8258..3886df8b9 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -19,7 +19,12 @@ ) from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from strands.types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + MaxTokensReachedException, + ModelThrottledException, +) from tests.fixtures.mock_hook_provider import MockHookProvider @@ -556,6 +561,51 @@ async def test_event_loop_tracing_with_model_error( mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect) +@pytest.mark.asyncio +async def test_event_loop_cycle_max_tokens_exception( + agent, + model, + agenerator, + alist, +): + """Test that max_tokens stop reason raises MaxTokensReachedException.""" + + # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 + model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": {}, + }, + }, + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "max_tokens"}}, + ] + ) + + # Call event_loop_cycle, expecting it to raise MaxTokensReachedException + with pytest.raises(MaxTokensReachedException) as exc_info: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + # Verify the exception message contains the expected content + expected_message = ( + "Agent has reached an unrecoverable state due to max_tokens limit. 
" + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + assert str(exc_info.value) == expected_message + + # Verify that the message has not been appended to the messages array + assert len(agent.messages) == 1 + assert exc_info.value.incomplete_message not in agent.messages + + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio async def test_event_loop_tracing_with_tool_execution( diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py new file mode 100644 index 000000000..d9c2817b3 --- /dev/null +++ b/tests_integ/test_max_tokens_reached.py @@ -0,0 +1,20 @@ +import pytest + +from strands import Agent, tool +from strands.models.bedrock import BedrockModel +from strands.types.exceptions import MaxTokensReachedException + + +@tool +def story_tool(story: str) -> str: + return story + + +def test_context_window_overflow(): + model = BedrockModel(max_tokens=100) + agent = Agent(model=model, tools=[story_tool]) + + with pytest.raises(MaxTokensReachedException): + agent("Tell me a story!") + + assert len(agent.messages) == 1 From 34d499aeea8ddb933c73711b1371704d4de8c9ba Mon Sep 17 00:00:00 2001 From: poshinchen Date: Tue, 5 Aug 2025 12:39:21 -0400 Subject: [PATCH 020/104] fix(telemetry): added mcp tracing context propagation (#569) --- src/strands/tools/mcp/mcp_client.py | 2 + src/strands/tools/mcp/mcp_instrumentation.py | 322 ++++++++++++ .../tools/mcp/test_mcp_instrumentation.py | 491 ++++++++++++++++++ 3 files changed, 815 insertions(+) create mode 100644 src/strands/tools/mcp/mcp_instrumentation.py create mode 100644 tests/strands/tools/mcp/test_mcp_instrumentation.py diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 8c21baa4a..c1aa96df3 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -29,6 +29,7 @@ from ...types.media import ImageFormat from ...types.tools import ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool +from .mcp_instrumentation import mcp_instrumentation from .mcp_types import MCPToolResult, MCPTransport logger = logging.getLogger(__name__) @@ -68,6 +69,7 @@ def __init__(self, transport_callable: Callable[[], MCPTransport]): Args: transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple """ + mcp_instrumentation() self._session_id = uuid.uuid4() self._log_debug_with_thread("initializing MCPClient connection") self._init_future: futures.Future[None] = futures.Future() # Main thread blocks until future completes diff --git a/src/strands/tools/mcp/mcp_instrumentation.py b/src/strands/tools/mcp/mcp_instrumentation.py new file mode 100644 index 000000000..338721db5 --- /dev/null +++ b/src/strands/tools/mcp/mcp_instrumentation.py @@ -0,0 +1,322 @@ +"""OpenTelemetry instrumentation for Model Context Protocol (MCP) tracing. + +Enables distributed tracing across MCP client-server boundaries by injecting +OpenTelemetry context into MCP request metadata (_meta field) and extracting +it on the server side, creating unified traces that span from agent calls +through MCP tool executions. 
+ +Based on: https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-mcp +Related issue: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/246 +""" + +from contextlib import _AsyncGeneratorContextManager, asynccontextmanager +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Callable, Tuple + +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCMessage, JSONRPCRequest +from opentelemetry import context, propagate +from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper + + +@dataclass(slots=True, frozen=True) +class ItemWithContext: + """Wrapper for items that need to carry OpenTelemetry context. + + Used to preserve tracing context across async boundaries in MCP sessions, + ensuring that distributed traces remain connected even when messages are + processed asynchronously. + + Attributes: + item: The original item being wrapped + ctx: The OpenTelemetry context associated with the item + """ + + item: Any + ctx: context.Context + + +def mcp_instrumentation() -> None: + """Apply OpenTelemetry instrumentation patches to MCP components. + + This function instruments three key areas of MCP communication: + 1. Client-side: Injects tracing context into tool call requests + 2. Transport-level: Extracts context from incoming messages + 3. Session-level: Manages bidirectional context flow + + The patches enable distributed tracing by: + - Adding OpenTelemetry context to the _meta field of MCP requests + - Extracting and activating context on the server side + - Preserving context across async message processing boundaries + """ + + def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any) -> Any: + """Patch MCP client to inject OpenTelemetry context into tool calls. + + Intercepts outgoing MCP requests and injects the current OpenTelemetry + context into the request's _meta field for tools/call methods. This + enables server-side context extraction and trace continuation. 
+ + Args: + wrapped: The original function being wrapped + instance: The instance the method is being called on + args: Positional arguments to the wrapped function + kwargs: Keyword arguments to the wrapped function + + Returns: + Result of the wrapped function call + """ + if len(args) < 1: + return wrapped(*args, **kwargs) + + request = args[0] + method = getattr(request.root, "method", None) + + if method != "tools/call": + return wrapped(*args, **kwargs) + + try: + if hasattr(request.root, "params") and request.root.params: + # Handle Pydantic models + if hasattr(request.root.params, "model_dump") and hasattr(request.root.params, "model_validate"): + params_dict = request.root.params.model_dump() + # Add _meta with tracing context + meta = params_dict.setdefault("_meta", {}) + propagate.get_global_textmap().inject(meta) + + # Recreate the Pydantic model with the updated data + # This preserves the original model type and avoids serialization warnings + params_class = type(request.root.params) + try: + request.root.params = params_class.model_validate(params_dict) + except Exception: + # Fallback to dict if model recreation fails + request.root.params = params_dict + + elif isinstance(request.root.params, dict): + # Handle dict params directly + meta = request.root.params.setdefault("_meta", {}) + propagate.get_global_textmap().inject(meta) + + return wrapped(*args, **kwargs) + + except Exception: + return wrapped(*args, **kwargs) + + def transport_wrapper() -> Callable[ + [Callable[..., Any], Any, Any, Any], _AsyncGeneratorContextManager[tuple[Any, Any]] + ]: + """Create a wrapper for MCP transport connections. + + Returns a context manager that wraps transport read/write streams + with context extraction capabilities. The wrapped reader will + automatically extract OpenTelemetry context from incoming messages. + + Returns: + An async context manager that yields wrapped transport streams + """ + + @asynccontextmanager + async def traced_method( + wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any + ) -> AsyncGenerator[Tuple[Any, Any], None]: + async with wrapped(*args, **kwargs) as result: + try: + read_stream, write_stream = result + except ValueError: + read_stream, write_stream, _ = result + yield TransportContextExtractingReader(read_stream), write_stream + + return traced_method + + def session_init_wrapper() -> Callable[[Any, Any, Tuple[Any, ...], dict[str, Any]], None]: + """Create a wrapper for MCP session initialization. + + Wraps session message streams to enable bidirectional context flow. + The reader extracts and activates context, while the writer preserves + context for async processing. 
+ + Returns: + A function that wraps session initialization + """ + + def traced_method( + wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: dict[str, Any] + ) -> None: + wrapped(*args, **kwargs) + reader = getattr(instance, "_incoming_message_stream_reader", None) + writer = getattr(instance, "_incoming_message_stream_writer", None) + if reader and writer: + instance._incoming_message_stream_reader = SessionContextAttachingReader(reader) + instance._incoming_message_stream_writer = SessionContextSavingWriter(writer) + + return traced_method + + # Apply patches + wrap_function_wrapper("mcp.shared.session", "BaseSession.send_request", patch_mcp_client) + + register_post_import_hook( + lambda _: wrap_function_wrapper( + "mcp.server.streamable_http", "StreamableHTTPServerTransport.connect", transport_wrapper() + ), + "mcp.server.streamable_http", + ) + + register_post_import_hook( + lambda _: wrap_function_wrapper("mcp.server.session", "ServerSession.__init__", session_init_wrapper()), + "mcp.server.session", + ) + + +class TransportContextExtractingReader(ObjectProxy): + """A proxy reader that extracts OpenTelemetry context from MCP messages. + + Wraps an async message stream reader to automatically extract and activate + OpenTelemetry context from the _meta field of incoming MCP requests. This + enables server-side trace continuation from client-injected context. + + The reader handles both SessionMessage and JSONRPCMessage formats, and + supports both dict and Pydantic model parameter structures. + """ + + def __init__(self, wrapped: Any) -> None: + """Initialize the context-extracting reader. + + Args: + wrapped: The original async stream reader to wrap + """ + super().__init__(wrapped) + + async def __aenter__(self) -> Any: + """Enter the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: + """Exit the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) + + async def __aiter__(self) -> AsyncGenerator[Any, None]: + """Iterate over messages, extracting and activating context as needed. + + For each incoming message, checks if it contains tracing context in + the _meta field. If found, extracts and activates the context for + the duration of message processing, then properly detaches it. + + Yields: + Messages from the wrapped stream, processed under the appropriate + OpenTelemetry context + """ + async for item in self.__wrapped__: + if isinstance(item, SessionMessage): + request = item.message.root + elif type(item) is JSONRPCMessage: + request = item.root + else: + yield item + continue + + if isinstance(request, JSONRPCRequest) and request.params: + # Handle both dict and Pydantic model params + if hasattr(request.params, "get"): + # Dict-like access + meta = request.params.get("_meta") + elif hasattr(request.params, "_meta"): + # Direct attribute access for Pydantic models + meta = getattr(request.params, "_meta", None) + else: + meta = None + + if meta: + extracted_context = propagate.extract(meta) + restore = context.attach(extracted_context) + try: + yield item + continue + finally: + context.detach(restore) + yield item + + +class SessionContextSavingWriter(ObjectProxy): + """A proxy writer that preserves OpenTelemetry context with outgoing items. 
+ + Wraps an async message stream writer to capture the current OpenTelemetry + context and associate it with outgoing items. This enables context + preservation across async boundaries in MCP session processing. + """ + + def __init__(self, wrapped: Any) -> None: + """Initialize the context-saving writer. + + Args: + wrapped: The original async stream writer to wrap + """ + super().__init__(wrapped) + + async def __aenter__(self) -> Any: + """Enter the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: + """Exit the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) + + async def send(self, item: Any) -> Any: + """Send an item while preserving the current OpenTelemetry context. + + Captures the current context and wraps the item with it, enabling + the receiving side to restore the appropriate tracing context. + + Args: + item: The item to send through the stream + + Returns: + Result of sending the wrapped item + """ + ctx = context.get_current() + return await self.__wrapped__.send(ItemWithContext(item, ctx)) + + +class SessionContextAttachingReader(ObjectProxy): + """A proxy reader that restores OpenTelemetry context from wrapped items. + + Wraps an async message stream reader to detect ItemWithContext instances + and restore their associated OpenTelemetry context during processing. + This completes the context preservation cycle started by SessionContextSavingWriter. + """ + + def __init__(self, wrapped: Any) -> None: + """Initialize the context-attaching reader. + + Args: + wrapped: The original async stream reader to wrap + """ + super().__init__(wrapped) + + async def __aenter__(self) -> Any: + """Enter the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any: + """Exit the async context manager by delegating to the wrapped object.""" + return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback) + + async def __aiter__(self) -> AsyncGenerator[Any, None]: + """Iterate over items, restoring context for ItemWithContext instances. + + For items wrapped with context, temporarily activates the associated + OpenTelemetry context during processing, then properly detaches it. + Regular items are yielded without context modification. 
+ + Yields: + Unwrapped items processed under their associated OpenTelemetry context + """ + async for item in self.__wrapped__: + if isinstance(item, ItemWithContext): + restore = context.attach(item.ctx) + try: + yield item.item + finally: + context.detach(restore) + else: + yield item diff --git a/tests/strands/tools/mcp/test_mcp_instrumentation.py b/tests/strands/tools/mcp/test_mcp_instrumentation.py new file mode 100644 index 000000000..61a485777 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_instrumentation.py @@ -0,0 +1,491 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCMessage, JSONRPCRequest +from opentelemetry import context, propagate + +from strands.tools.mcp.mcp_instrumentation import ( + ItemWithContext, + SessionContextAttachingReader, + SessionContextSavingWriter, + TransportContextExtractingReader, + mcp_instrumentation, +) + + +class TestItemWithContext: + def test_item_with_context_creation(self): + """Test that ItemWithContext correctly stores item and context.""" + test_item = {"test": "data"} + test_context = context.get_current() + + wrapped = ItemWithContext(test_item, test_context) + + assert wrapped.item == test_item + assert wrapped.ctx == test_context + + +class TestTransportContextExtractingReader: + @pytest.fixture + def mock_wrapped_reader(self): + """Create a mock wrapped reader.""" + mock_reader = AsyncMock() + mock_reader.__aenter__ = AsyncMock(return_value=mock_reader) + mock_reader.__aexit__ = AsyncMock() + return mock_reader + + def test_init(self, mock_wrapped_reader): + """Test reader initialization.""" + reader = TransportContextExtractingReader(mock_wrapped_reader) + assert reader.__wrapped__ == mock_wrapped_reader + + @pytest.mark.asyncio + async def test_context_manager_methods(self, mock_wrapped_reader): + """Test async context manager methods delegate correctly.""" + reader = TransportContextExtractingReader(mock_wrapped_reader) + + await reader.__aenter__() + mock_wrapped_reader.__aenter__.assert_called_once() + + await reader.__aexit__(None, None, None) + mock_wrapped_reader.__aexit__.assert_called_once_with(None, None, None) + + @pytest.mark.asyncio + async def test_aiter_with_session_message_and_dict_meta(self, mock_wrapped_reader): + """Test context extraction from SessionMessage with dict params containing _meta.""" + # Create mock message with dict params containing _meta + mock_request = MagicMock(spec=JSONRPCRequest) + mock_request.params = {"_meta": {"traceparent": "test-trace-id"}, "other": "data"} + + mock_message = MagicMock() + mock_message.root = mock_request + + mock_session_message = MagicMock(spec=SessionMessage) + mock_session_message.message = mock_message + + async def async_iter(): + for item in [mock_session_message]: + yield item + + mock_wrapped_reader.__aiter__ = lambda self: async_iter() + + reader = TransportContextExtractingReader(mock_wrapped_reader) + + with ( + patch.object(propagate, "extract") as mock_extract, + patch.object(context, "attach") as mock_attach, + patch.object(context, "detach") as mock_detach, + ): + mock_context = MagicMock() + mock_extract.return_value = mock_context + mock_token = MagicMock() + mock_attach.return_value = mock_token + + items = [] + async for item in reader: + items.append(item) + + assert len(items) == 1 + assert items[0] == mock_session_message + + mock_extract.assert_called_once_with({"traceparent": "test-trace-id"}) + 
mock_attach.assert_called_once_with(mock_context) + mock_detach.assert_called_once_with(mock_token) + + @pytest.mark.asyncio + async def test_aiter_with_session_message_and_pydantic_meta(self, mock_wrapped_reader): + """Test context extraction from SessionMessage with Pydantic params having _meta attribute.""" + # Create mock message with Pydantic-style params + mock_request = MagicMock(spec=JSONRPCRequest) + + # Create a mock params object that doesn't have 'get' method but has '_meta' attribute + mock_params = MagicMock() + # Remove the get method to simulate Pydantic model behavior + del mock_params.get + mock_params._meta = {"traceparent": "test-trace-id"} + mock_request.params = mock_params + + mock_message = MagicMock() + mock_message.root = mock_request + + mock_session_message = MagicMock(spec=SessionMessage) + mock_session_message.message = mock_message + + async def async_iter(): + for item in [mock_session_message]: + yield item + + mock_wrapped_reader.__aiter__ = lambda self: async_iter() + + reader = TransportContextExtractingReader(mock_wrapped_reader) + + with ( + patch.object(propagate, "extract") as mock_extract, + patch.object(context, "attach") as mock_attach, + patch.object(context, "detach") as mock_detach, + ): + mock_context = MagicMock() + mock_extract.return_value = mock_context + mock_token = MagicMock() + mock_attach.return_value = mock_token + + items = [] + async for item in reader: + items.append(item) + + assert len(items) == 1 + assert items[0] == mock_session_message + + mock_extract.assert_called_once_with({"traceparent": "test-trace-id"}) + mock_attach.assert_called_once_with(mock_context) + mock_detach.assert_called_once_with(mock_token) + + @pytest.mark.asyncio + async def test_aiter_with_jsonrpc_message_no_meta(self, mock_wrapped_reader): + """Test handling JSONRPCMessage without _meta.""" + mock_request = MagicMock(spec=JSONRPCRequest) + mock_request.params = {"other": "data"} + + mock_message = MagicMock(spec=JSONRPCMessage) + mock_message.root = mock_request + + async def async_iter(): + for item in [mock_message]: + yield item + + mock_wrapped_reader.__aiter__ = lambda self: async_iter() + + reader = TransportContextExtractingReader(mock_wrapped_reader) + + items = [] + async for item in reader: + items.append(item) + + assert len(items) == 1 + assert items[0] == mock_message + + @pytest.mark.asyncio + async def test_aiter_with_non_message_item(self, mock_wrapped_reader): + """Test handling non-message items.""" + other_item = {"not": "a message"} + + async def async_iter(): + for item in [other_item]: + yield item + + mock_wrapped_reader.__aiter__ = lambda self: async_iter() + + reader = TransportContextExtractingReader(mock_wrapped_reader) + + items = [] + async for item in reader: + items.append(item) + + assert len(items) == 1 + assert items[0] == other_item + + +class TestSessionContextSavingWriter: + @pytest.fixture + def mock_wrapped_writer(self): + """Create a mock wrapped writer.""" + mock_writer = AsyncMock() + mock_writer.__aenter__ = AsyncMock(return_value=mock_writer) + mock_writer.__aexit__ = AsyncMock() + mock_writer.send = AsyncMock() + return mock_writer + + def test_init(self, mock_wrapped_writer): + """Test writer initialization.""" + writer = SessionContextSavingWriter(mock_wrapped_writer) + assert writer.__wrapped__ == mock_wrapped_writer + + @pytest.mark.asyncio + async def test_context_manager_methods(self, mock_wrapped_writer): + """Test async context manager methods delegate correctly.""" + writer = 
SessionContextSavingWriter(mock_wrapped_writer) + + await writer.__aenter__() + mock_wrapped_writer.__aenter__.assert_called_once() + + await writer.__aexit__(None, None, None) + mock_wrapped_writer.__aexit__.assert_called_once_with(None, None, None) + + @pytest.mark.asyncio + async def test_send_wraps_item_with_context(self, mock_wrapped_writer): + """Test that send wraps items with current context.""" + writer = SessionContextSavingWriter(mock_wrapped_writer) + test_item = {"test": "data"} + + with patch.object(context, "get_current") as mock_get_current: + mock_context = MagicMock() + mock_get_current.return_value = mock_context + + await writer.send(test_item) + + mock_get_current.assert_called_once() + mock_wrapped_writer.send.assert_called_once() + + # Verify the item was wrapped with context + sent_item = mock_wrapped_writer.send.call_args[0][0] + assert isinstance(sent_item, ItemWithContext) + assert sent_item.item == test_item + assert sent_item.ctx == mock_context + + +class TestSessionContextAttachingReader: + @pytest.fixture + def mock_wrapped_reader(self): + """Create a mock wrapped reader.""" + mock_reader = AsyncMock() + mock_reader.__aenter__ = AsyncMock(return_value=mock_reader) + mock_reader.__aexit__ = AsyncMock() + return mock_reader + + def test_init(self, mock_wrapped_reader): + """Test reader initialization.""" + reader = SessionContextAttachingReader(mock_wrapped_reader) + assert reader.__wrapped__ == mock_wrapped_reader + + @pytest.mark.asyncio + async def test_context_manager_methods(self, mock_wrapped_reader): + """Test async context manager methods delegate correctly.""" + reader = SessionContextAttachingReader(mock_wrapped_reader) + + await reader.__aenter__() + mock_wrapped_reader.__aenter__.assert_called_once() + + await reader.__aexit__(None, None, None) + mock_wrapped_reader.__aexit__.assert_called_once_with(None, None, None) + + @pytest.mark.asyncio + async def test_aiter_with_item_with_context(self, mock_wrapped_reader): + """Test context restoration from ItemWithContext.""" + test_item = {"test": "data"} + test_context = MagicMock() + wrapped_item = ItemWithContext(test_item, test_context) + + async def async_iter(): + for item in [wrapped_item]: + yield item + + mock_wrapped_reader.__aiter__ = lambda self: async_iter() + + reader = SessionContextAttachingReader(mock_wrapped_reader) + + with patch.object(context, "attach") as mock_attach, patch.object(context, "detach") as mock_detach: + mock_token = MagicMock() + mock_attach.return_value = mock_token + + items = [] + async for item in reader: + items.append(item) + + assert len(items) == 1 + assert items[0] == test_item + + mock_attach.assert_called_once_with(test_context) + mock_detach.assert_called_once_with(mock_token) + + @pytest.mark.asyncio + async def test_aiter_with_regular_item(self, mock_wrapped_reader): + """Test handling regular items without context.""" + regular_item = {"regular": "item"} + + async def async_iter(): + for item in [regular_item]: + yield item + + mock_wrapped_reader.__aiter__ = lambda self: async_iter() + + reader = SessionContextAttachingReader(mock_wrapped_reader) + + items = [] + async for item in reader: + items.append(item) + + assert len(items) == 1 + assert items[0] == regular_item + + +# Mock Pydantic-like class for testing +class MockPydanticParams: + """Mock class that behaves like a Pydantic model.""" + + def __init__(self, **data): + self._data = data + + def model_dump(self): + return self._data.copy() + + @classmethod + def model_validate(cls, data): + return 
cls(**data) + + def __getattr__(self, name): + return self._data.get(name) + + +class TestMCPInstrumentation: + def test_mcp_instrumentation_calls_wrap_function_wrapper(self): + """Test that mcp_instrumentation calls the expected wrapper functions.""" + with ( + patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap, + patch("strands.tools.mcp.mcp_instrumentation.register_post_import_hook") as mock_register, + ): + mcp_instrumentation() + + # Verify wrap_function_wrapper was called for client patching + mock_wrap.assert_called_once_with( + "mcp.shared.session", + "BaseSession.send_request", + mock_wrap.call_args_list[0][0][2], # The patch function + ) + + # Verify register_post_import_hook was called for transport and session wrappers + assert mock_register.call_count == 2 + + # Check that the registered hooks are for the expected modules + registered_modules = [call[0][1] for call in mock_register.call_args_list] + assert "mcp.server.streamable_http" in registered_modules + assert "mcp.server.session" in registered_modules + + def test_patch_mcp_client_injects_context_pydantic_model(self): + """Test that the client patch injects OpenTelemetry context into Pydantic models.""" + # Create a mock request with tools/call method and Pydantic params + mock_request = MagicMock() + mock_request.root.method = "tools/call" + + # Use our mock Pydantic-like class + mock_params = MockPydanticParams(existing="param") + mock_request.root.params = mock_params + + # Create the patch function + with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + mcp_instrumentation() + patch_function = mock_wrap.call_args_list[0][0][2] + + # Mock the wrapped function + mock_wrapped = MagicMock() + + with patch.object(propagate, "get_global_textmap") as mock_textmap: + mock_textmap_instance = MagicMock() + mock_textmap.return_value = mock_textmap_instance + + # Call the patch function + patch_function(mock_wrapped, None, [mock_request], {}) + + # Verify context was injected + mock_textmap_instance.inject.assert_called_once() + mock_wrapped.assert_called_once_with(mock_request) + + # Verify the params object is still a MockPydanticParams (or dict if fallback occurred) + assert hasattr(mock_request.root.params, "model_dump") or isinstance(mock_request.root.params, dict) + + def test_patch_mcp_client_injects_context_dict_params(self): + """Test that the client patch injects OpenTelemetry context into dict params.""" + # Create a mock request with tools/call method and dict params + mock_request = MagicMock() + mock_request.root.method = "tools/call" + mock_request.root.params = {"existing": "param"} + + # Create the patch function + with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + mcp_instrumentation() + patch_function = mock_wrap.call_args_list[0][0][2] + + # Mock the wrapped function + mock_wrapped = MagicMock() + + with patch.object(propagate, "get_global_textmap") as mock_textmap: + mock_textmap_instance = MagicMock() + mock_textmap.return_value = mock_textmap_instance + + # Call the patch function + patch_function(mock_wrapped, None, [mock_request], {}) + + # Verify context was injected + mock_textmap_instance.inject.assert_called_once() + mock_wrapped.assert_called_once_with(mock_request) + + # Verify _meta was added to the params dict + assert "_meta" in mock_request.root.params + + def test_patch_mcp_client_skips_non_tools_call(self): + """Test that the client patch skips non-tools/call methods.""" + mock_request = 
MagicMock() + mock_request.root.method = "other/method" + + with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + mcp_instrumentation() + patch_function = mock_wrap.call_args_list[0][0][2] + + mock_wrapped = MagicMock() + + with patch.object(propagate, "get_global_textmap") as mock_textmap: + mock_textmap_instance = MagicMock() + mock_textmap.return_value = mock_textmap_instance + + patch_function(mock_wrapped, None, [mock_request], {}) + + # Verify context injection was skipped + mock_textmap_instance.inject.assert_not_called() + mock_wrapped.assert_called_once_with(mock_request) + + def test_patch_mcp_client_handles_exception_gracefully(self): + """Test that the client patch handles exceptions gracefully.""" + # Create a mock request that will cause an exception + mock_request = MagicMock() + mock_request.root.method = "tools/call" + mock_request.root.params = MagicMock() + mock_request.root.params.model_dump.side_effect = Exception("Test exception") + + with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + mcp_instrumentation() + patch_function = mock_wrap.call_args_list[0][0][2] + + mock_wrapped = MagicMock() + + # Should not raise an exception, should call wrapped function normally + patch_function(mock_wrapped, None, [mock_request], {}) + mock_wrapped.assert_called_once_with(mock_request) + + def test_patch_mcp_client_pydantic_fallback_to_dict(self): + """Test that Pydantic model recreation falls back to dict on failure.""" + + # Create a Pydantic-like class that fails on model_validate + class FailingMockPydanticParams: + def __init__(self, **data): + self._data = data + + def model_dump(self): + return self._data.copy() + + def model_validate(self, data): + raise Exception("Reconstruction failed") + + # Create a mock request with failing Pydantic params + mock_request = MagicMock() + mock_request.root.method = "tools/call" + + failing_params = FailingMockPydanticParams(existing="param") + mock_request.root.params = failing_params + + with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + mcp_instrumentation() + patch_function = mock_wrap.call_args_list[0][0][2] + + mock_wrapped = MagicMock() + + with patch.object(propagate, "get_global_textmap") as mock_textmap: + mock_textmap_instance = MagicMock() + mock_textmap.return_value = mock_textmap_instance + + # Call the patch function + patch_function(mock_wrapped, None, [mock_request], {}) + + # Verify it fell back to dict + assert isinstance(mock_request.root.params, dict) + assert "_meta" in mock_request.root.params + mock_wrapped.assert_called_once_with(mock_request) From 09ca806adf2f7efa367c812514435eb0089dcd0a Mon Sep 17 00:00:00 2001 From: Vince Mi Date: Tue, 5 Aug 2025 10:32:30 -0700 Subject: [PATCH 021/104] Change max_tokens type to int to match Anthropic API (#588) Using a string causes the Anthropic API call to fail: ``` anthropic.BadRequestError: Error code: 400 - {'type': 'error', 'error': {'type': 'invalid_request_error', 'message': 'max_tokens: Input should be a valid integer'}} ``` --- src/strands/models/anthropic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 975fca3e9..29cb40d40 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -55,7 +55,7 @@ class AnthropicConfig(TypedDict, total=False): For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages. 
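    As an illustrative sketch only (the ``AnthropicModel`` class name and its keyword arguments are assumed to mirror the config keys above; the model id is a placeholder), the corrected typing means ``max_tokens`` is passed as an integer:

    ```
    from strands import Agent
    from strands.models.anthropic import AnthropicModel

    # Assumed constructor: keyword arguments mirror AnthropicConfig.
    model = AnthropicModel(
        model_id="<anthropic-model-id>",  # placeholder
        max_tokens=1024,                  # int, not "1024"
    )
    agent = Agent(model=model)
    ```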
""" - max_tokens: Required[str] + max_tokens: Required[int] model_id: Required[str] params: Optional[dict[str, Any]] From bf24ebf4d479cedea3f74452d8309142232203f9 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Wed, 6 Aug 2025 09:40:42 -0400 Subject: [PATCH 022/104] feat: Add additional intructions for contributors to find issues that are ready to be worked on (#595) --- CONTRIBUTING.md | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fa724cddc..add4825fd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -25,6 +25,17 @@ Please try to include as much information as you can. Details like these are inc * Anything unusual about your environment or deployment +## Finding contributions to work on +Looking at the existing issues is a great way to find something to contribute to. We label issues that are well-defined and ready for community contributions with the "ready for contribution" label. + +Check our [Ready for Contribution](../../issues?q=is%3Aissue%20state%3Aopen%20label%3A%22ready%20for%20contribution%22) issues for items you can work on. + +Before starting work on any issue: +1. Check if someone is already assigned or working on it +2. Comment on the issue to express your interest and ask any clarifying questions +3. Wait for maintainer confirmation before beginning significant work + + ## Development Environment This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as the build backend and [hatch](https://hatch.pypa.io/latest/) for development workflow management. @@ -70,7 +81,7 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as ### Pre-commit Hooks -We use [pre-commit](https://pre-commit.com/) to automatically run quality checks before each commit. The hook will run `hatch run format`, `hatch run lint`, `hatch run test`, and `hatch run cz check` on when you make a commit, ensuring code consistency. +We use [pre-commit](https://pre-commit.com/) to automatically run quality checks before each commit. The hook will run `hatch run format`, `hatch run lint`, `hatch run test`, and `hatch run cz check` when you make a commit, ensuring code consistency. The pre-commit hook is installed with: @@ -122,14 +133,6 @@ To send us a pull request, please: 8. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. -## Finding contributions to work on -Looking at the existing issues is a great way to find something to contribute to. - -You can check: -- Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing -- Feature requests in [Feature Requests](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Aenhancement) for new functionality to implement - - ## Code of Conduct This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact From 297ec5cdfcd4b1e6e178429ef657911654e865b0 Mon Sep 17 00:00:00 2001 From: Jeremiah Date: Wed, 6 Aug 2025 09:54:49 -0400 Subject: [PATCH 023/104] feat(a2a): configurable request handler (#601) Co-authored-by: jer --- src/strands/multiagent/a2a/server.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index fa7b6b887..35ea5b2e3 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -10,8 +10,9 @@ import uvicorn from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication +from a2a.server.events import QueueManager from a2a.server.request_handlers import DefaultRequestHandler -from a2a.server.tasks import InMemoryTaskStore +from a2a.server.tasks import InMemoryTaskStore, PushNotificationConfigStore, PushNotificationSender, TaskStore from a2a.types import AgentCapabilities, AgentCard, AgentSkill from fastapi import FastAPI from starlette.applications import Starlette @@ -36,6 +37,12 @@ def __init__( serve_at_root: bool = False, version: str = "0.0.1", skills: list[AgentSkill] | None = None, + # RequestHandler + task_store: TaskStore | None = None, + queue_manager: QueueManager | None = None, + push_config_store: PushNotificationConfigStore | None = None, + push_sender: PushNotificationSender | None = None, + ): """Initialize an A2A-compatible server from a Strands agent. @@ -52,6 +59,14 @@ def __init__( Defaults to False. version: The version of the agent. Defaults to "0.0.1". skills: The list of capabilities or functions the agent can perform. + task_store: Custom task store implementation for managing agent tasks. If None, + uses InMemoryTaskStore. + queue_manager: Custom queue manager for handling message queues. If None, + no queue management is used. + push_config_store: Custom store for push notification configurations. If None, + no push notification configuration is used. + push_sender: Custom push notification sender implementation. If None, + no push notifications are sent. """ self.host = host self.port = port @@ -77,7 +92,10 @@ def __init__( self.capabilities = AgentCapabilities(streaming=True) self.request_handler = DefaultRequestHandler( agent_executor=StrandsA2AExecutor(self.strands_agent), - task_store=InMemoryTaskStore(), + task_store=task_store or InMemoryTaskStore(), + queue_manager=queue_manager, + push_config_store=push_config_store, + push_sender=push_sender, ) self._agent_skills = skills logger.info("Strands' integration with A2A is experimental. 
Be aware of frequent breaking changes.") From ec5304c39809b99b1d29ed03ce7ae40536575e95 Mon Sep 17 00:00:00 2001 From: Jeremiah Date: Wed, 6 Aug 2025 12:48:13 -0400 Subject: [PATCH 024/104] chore(a2a): update host per AppSec recommendation (#619) Co-authored-by: jer --- src/strands/multiagent/a2a/server.py | 5 ++--- tests/strands/multiagent/a2a/test_server.py | 10 +++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index 35ea5b2e3..bbfbc824d 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -31,7 +31,7 @@ def __init__( agent: SAAgent, *, # AgentCard - host: str = "0.0.0.0", + host: str = "127.0.0.1", port: int = 9000, http_url: str | None = None, serve_at_root: bool = False, @@ -42,13 +42,12 @@ def __init__( queue_manager: QueueManager | None = None, push_config_store: PushNotificationConfigStore | None = None, push_sender: PushNotificationSender | None = None, - ): """Initialize an A2A-compatible server from a Strands agent. Args: agent: The Strands Agent to wrap with A2A compatibility. - host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0". + host: The hostname or IP address to bind the A2A server to. Defaults to "127.0.0.1". port: The port to bind the A2A server to. Defaults to 9000. http_url: The public HTTP URL where this agent will be accessible. If provided, this overrides the generated URL from host/port and enables automatic diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py index a3b47581c..00dd164b5 100644 --- a/tests/strands/multiagent/a2a/test_server.py +++ b/tests/strands/multiagent/a2a/test_server.py @@ -22,9 +22,9 @@ def test_a2a_agent_initialization(mock_strands_agent): assert a2a_agent.strands_agent == mock_strands_agent assert a2a_agent.name == "Test Agent" assert a2a_agent.description == "A test agent for unit testing" - assert a2a_agent.host == "0.0.0.0" + assert a2a_agent.host == "127.0.0.1" assert a2a_agent.port == 9000 - assert a2a_agent.http_url == "http://0.0.0.0:9000/" + assert a2a_agent.http_url == "http://127.0.0.1:9000/" assert a2a_agent.version == "0.0.1" assert isinstance(a2a_agent.capabilities, AgentCapabilities) assert len(a2a_agent.agent_skills) == 1 @@ -85,7 +85,7 @@ def test_public_agent_card(mock_strands_agent): assert isinstance(card, AgentCard) assert card.name == "Test Agent" assert card.description == "A test agent for unit testing" - assert card.url == "http://0.0.0.0:9000/" + assert card.url == "http://127.0.0.1:9000/" assert card.version == "0.0.1" assert card.default_input_modes == ["text"] assert card.default_output_modes == ["text"] @@ -448,7 +448,7 @@ def test_serve_with_starlette(mock_run, mock_strands_agent): mock_run.assert_called_once() args, kwargs = mock_run.call_args assert isinstance(args[0], Starlette) - assert kwargs["host"] == "0.0.0.0" + assert kwargs["host"] == "127.0.0.1" assert kwargs["port"] == 9000 @@ -462,7 +462,7 @@ def test_serve_with_fastapi(mock_run, mock_strands_agent): mock_run.assert_called_once() args, kwargs = mock_run.call_args assert isinstance(args[0], FastAPI) - assert kwargs["host"] == "0.0.0.0" + assert kwargs["host"] == "127.0.0.1" assert kwargs["port"] == 9000 From 29b21278f5816ffa01dbb555bb6ff192ae105d59 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 8 Aug 2025 10:43:34 -0400 Subject: [PATCH 025/104] fix(event_loop): ensure tool_use content blocks are valid after max_tokens 
to prevent unrecoverable state (#607) --- .../_recover_message_on_max_tokens_reached.py | 71 +++++ src/strands/event_loop/event_loop.py | 32 ++- src/strands/types/exceptions.py | 6 +- tests/strands/event_loop/test_event_loop.py | 55 ++-- ...t_recover_message_on_max_tokens_reached.py | 269 ++++++++++++++++++ tests_integ/test_max_tokens_reached.py | 32 ++- 6 files changed, 420 insertions(+), 45 deletions(-) create mode 100644 src/strands/event_loop/_recover_message_on_max_tokens_reached.py create mode 100644 tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py new file mode 100644 index 000000000..ab6fb4abe --- /dev/null +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -0,0 +1,71 @@ +"""Message recovery utilities for handling max token limit scenarios. + +This module provides functionality to recover and clean up incomplete messages that occur +when model responses are truncated due to maximum token limits being reached. It specifically +handles cases where tool use blocks are incomplete or malformed due to truncation. +""" + +import logging + +from ..types.content import ContentBlock, Message +from ..types.tools import ToolUse + +logger = logging.getLogger(__name__) + + +def recover_message_on_max_tokens_reached(message: Message) -> Message: + """Recover and clean up messages when max token limits are reached. + + When a model response is truncated due to maximum token limits, all tool use blocks + should be replaced with informative error messages since they may be incomplete or + unreliable. This function inspects the message content and: + + 1. Identifies all tool use blocks (regardless of validity) + 2. Replaces all tool uses with informative error messages + 3. Preserves all non-tool content blocks (text, images, etc.) + 4. Returns a cleaned message suitable for conversation history + + This recovery mechanism ensures that the conversation can continue gracefully even when + model responses are truncated, providing clear feedback about what happened and preventing + potentially incomplete or corrupted tool executions. + + Args: + message: The potentially incomplete message from the model that was truncated + due to max token limits. + + Returns: + A cleaned Message with all tool uses replaced by explanatory text content. + The returned message maintains the same role as the input message. + + Example: + If a message contains any tool use (complete or incomplete): + ``` + {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}} + ``` + + It will be replaced with: + ``` + {"text": "The selected tool calculator's tool use was incomplete due to maximum token limits being reached."} + ``` + """ + logger.info("handling max_tokens stop reason - replacing all tool uses with error messages") + + valid_content: list[ContentBlock] = [] + for content in message["content"] or []: + tool_use: ToolUse | None = content.get("toolUse") + if not tool_use: + valid_content.append(content) + continue + + # Replace all tool uses with error messages when max_tokens is reached + display_name = tool_use.get("name") or "" + logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name) + + valid_content.append( + { + "text": f"The selected tool {display_name}'s tool use was incomplete due " + f"to maximum token limits being reached." 
+ } + ) + + return {"content": valid_content, "role": message["role"]} diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index ae21d4c6d..b36f73155 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -36,6 +36,7 @@ ) from ..types.streaming import Metrics, StopReason from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse +from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached from .streaming import stream_messages if TYPE_CHECKING: @@ -156,6 +157,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) ) + if stop_reason == "max_tokens": + message = recover_message_on_max_tokens_reached(message) + if model_invoke_span: tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason) break # Success! Break out of retry loop @@ -192,6 +196,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> raise e try: + # Add message in trace and mark the end of the stream messages trace + stream_trace.add_message(message) + stream_trace.end() + + # Add the response message to the conversation + agent.messages.append(message) + agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) + yield {"callback": {"message": message}} + + # Update metrics + agent.event_loop_metrics.update_usage(usage) + agent.event_loop_metrics.update_metrics(metrics) + if stop_reason == "max_tokens": """ Handle max_tokens limit reached by the model. @@ -205,21 +222,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> "Agent has reached an unrecoverable state due to max_tokens limit. " "For more information see: " "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" - ), - incomplete_message=message, + ) ) - # Add message in trace and mark the end of the stream messages trace - stream_trace.add_message(message) - stream_trace.end() - - # Add the response message to the conversation - agent.messages.append(message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield {"callback": {"message": message}} - - # Update metrics - agent.event_loop_metrics.update_usage(usage) - agent.event_loop_metrics.update_metrics(metrics) # If the model is requesting to use tools if stop_reason == "tool_use": diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 71ea28b9f..90f2b8d7f 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -2,8 +2,6 @@ from typing import Any -from strands.types.content import Message - class EventLoopException(Exception): """Exception raised by the event loop.""" @@ -28,14 +26,12 @@ class MaxTokensReachedException(Exception): the complexity of the response, or when the model naturally reaches its configured output limit during generation. """ - def __init__(self, message: str, incomplete_message: Message): + def __init__(self, message: str): """Initialize the exception with an error message and the incomplete message object. 
Args: message: The error message describing the token limit issue - incomplete_message: The valid Message object with incomplete content due to token limits """ - self.incomplete_message = incomplete_message super().__init__(message) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 3886df8b9..191ab51ba 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -305,8 +305,10 @@ async def test_event_loop_cycle_text_response_error( await alist(stream) +@patch("strands.event_loop.event_loop.recover_message_on_max_tokens_reached") @pytest.mark.asyncio async def test_event_loop_cycle_tool_result( + mock_recover_message, agent, model, system_prompt, @@ -339,6 +341,9 @@ async def test_event_loop_cycle_tool_result( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state + # Verify that recover_message_on_max_tokens_reached was NOT called for tool_use stop reason + mock_recover_message.assert_not_called() + model.stream.assert_called_with( [ {"role": "user", "content": [{"text": "Hello"}]}, @@ -568,25 +573,35 @@ async def test_event_loop_cycle_max_tokens_exception( agenerator, alist, ): - """Test that max_tokens stop reason raises MaxTokensReachedException.""" + """Test that max_tokens stop reason calls _recover_message_on_max_tokens_reached then MaxTokensReachedException.""" - # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 - model.stream.return_value = agenerator( - [ - { - "contentBlockStart": { - "start": { - "toolUse": {}, + model.stream.side_effect = [ + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "asdf", + "input": {}, # empty + }, + }, }, }, - }, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "max_tokens"}}, - ] - ) + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "max_tokens"}}, + ] + ), + ] # Call event_loop_cycle, expecting it to raise MaxTokensReachedException - with pytest.raises(MaxTokensReachedException) as exc_info: + expected_message = ( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + with pytest.raises(MaxTokensReachedException, match=expected_message): stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -594,16 +609,8 @@ async def test_event_loop_cycle_max_tokens_exception( await alist(stream) # Verify the exception message contains the expected content - expected_message = ( - "Agent has reached an unrecoverable state due to max_tokens limit. 
" - "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" - ) - assert str(exc_info.value) == expected_message - - # Verify that the message has not been appended to the messages array - assert len(agent.messages) == 1 - assert exc_info.value.incomplete_message not in agent.messages + assert len(agent.messages) == 2 + assert "tool use was incomplete due" in agent.messages[1]["content"][0]["text"] @patch("strands.event_loop.event_loop.get_tracer") diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py new file mode 100644 index 000000000..402e90966 --- /dev/null +++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py @@ -0,0 +1,269 @@ +"""Tests for token limit recovery utility.""" + +from strands.event_loop._recover_message_on_max_tokens_reached import ( + recover_message_on_max_tokens_reached, +) +from strands.types.content import Message + + +def test_recover_message_on_max_tokens_reached_with_incomplete_tool_use(): + """Test recovery when incomplete tool use is present in the message.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "I'll help you with that."}, + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 2 + + # First content block should be preserved + assert result["content"][0] == {"text": "I'll help you with that."} + + # Second content block should be replaced with error message + assert "text" in result["content"][1] + assert "calculator" in result["content"][1]["text"] + assert "incomplete due to maximum token limits" in result["content"][1]["text"] + + +def test_recover_message_on_max_tokens_reached_with_missing_tool_name(): + """Test recovery when tool use has no name.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Missing name + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 1 + + # Content should be replaced with error message using + assert "text" in result["content"][0] + assert "" in result["content"][0]["text"] + assert "incomplete due to maximum token limits" in result["content"][0]["text"] + + +def test_recover_message_on_max_tokens_reached_with_missing_input(): + """Test recovery when tool use has no input.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "calculator", "toolUseId": "123"}}, # Missing input + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 1 + + # Content should be replaced with error message + assert "text" in result["content"][0] + assert "calculator" in result["content"][0]["text"] + assert "incomplete due to maximum token limits" in result["content"][0]["text"] + + +def test_recover_message_on_max_tokens_reached_with_missing_tool_use_id(): + """Test recovery when tool use has no toolUseId.""" + incomplete_message: Message = { + "role": "assistant", + 
"content": [ + {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}}}, # Missing toolUseId + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 1 + + # Content should be replaced with error message + assert "text" in result["content"][0] + assert "calculator" in result["content"][0]["text"] + assert "incomplete due to maximum token limits" in result["content"][0]["text"] + + +def test_recover_message_on_max_tokens_reached_with_valid_tool_use(): + """Test that even valid tool uses are replaced with error messages.""" + complete_message: Message = { + "role": "assistant", + "content": [ + {"text": "I'll help you with that."}, + {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}, # Valid + ], + } + + result = recover_message_on_max_tokens_reached(complete_message) + + # Should replace even valid tool uses with error messages + assert result["role"] == "assistant" + assert len(result["content"]) == 2 + assert result["content"][0] == {"text": "I'll help you with that."} + + # Valid tool use should also be replaced with error message + assert "text" in result["content"][1] + assert "calculator" in result["content"][1]["text"] + assert "incomplete due to maximum token limits" in result["content"][1]["text"] + + +def test_recover_message_on_max_tokens_reached_with_empty_content(): + """Test handling of message with empty content.""" + empty_message: Message = {"role": "assistant", "content": []} + + result = recover_message_on_max_tokens_reached(empty_message) + + # Should return message with empty content preserved + assert result["role"] == "assistant" + assert result["content"] == [] + + +def test_recover_message_on_max_tokens_reached_with_none_content(): + """Test handling of message with None content.""" + none_content_message: Message = {"role": "assistant", "content": None} + + result = recover_message_on_max_tokens_reached(none_content_message) + + # Should return message with empty content + assert result["role"] == "assistant" + assert result["content"] == [] + + +def test_recover_message_on_max_tokens_reached_with_mixed_content(): + """Test recovery with mix of valid content and incomplete tool use.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "Let me calculate this for you."}, + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete + {"text": "And then I'll explain the result."}, + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First and third content blocks should be preserved + assert result["content"][0] == {"text": "Let me calculate this for you."} + assert result["content"][2] == {"text": "And then I'll explain the result."} + + # Second content block should be replaced with error message + assert "text" in result["content"][1] + assert "calculator" in result["content"][1]["text"] + assert "incomplete due to maximum token limits" in result["content"][1]["text"] + + +def test_recover_message_on_max_tokens_reached_preserves_non_tool_content(): + """Test that non-tool content is preserved as-is.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "Here's some text."}, + {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}, + 
{"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Incomplete + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First two content blocks should be preserved exactly + assert result["content"][0] == {"text": "Here's some text."} + assert result["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}} + + # Third content block should be replaced with error message + assert "text" in result["content"][2] + assert "" in result["content"][2]["text"] + assert "incomplete due to maximum token limits" in result["content"][2]["text"] + + +def test_recover_message_on_max_tokens_reached_multiple_incomplete_tools(): + """Test recovery with multiple incomplete tool uses.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "calculator", "input": {}}}, # Missing toolUseId + {"text": "Some text in between."}, + {"toolUse": {"name": "", "input": {}, "toolUseId": "456"}}, # Missing name + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First tool use should be replaced + assert "text" in result["content"][0] + assert "calculator" in result["content"][0]["text"] + assert "incomplete due to maximum token limits" in result["content"][0]["text"] + + # Text content should be preserved + assert result["content"][1] == {"text": "Some text in between."} + + # Second tool use should be replaced with + assert "text" in result["content"][2] + assert "" in result["content"][2]["text"] + assert "incomplete due to maximum token limits" in result["content"][2]["text"] + + +def test_recover_message_on_max_tokens_reached_preserves_user_role(): + """Test that the function preserves the original message role.""" + incomplete_message: Message = { + "role": "user", + "content": [ + {"toolUse": {"name": "calculator", "input": {}}}, # Missing toolUseId + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Should preserve the original role + assert result["role"] == "user" + assert len(result["content"]) == 1 + assert "text" in result["content"][0] + assert "calculator" in result["content"][0]["text"] + + +def test_recover_message_on_max_tokens_reached_with_content_without_tool_use(): + """Test handling of content blocks that don't have toolUse key.""" + message: Message = { + "role": "assistant", + "content": [ + {"text": "Regular text content."}, + {"someOtherKey": "someValue"}, # Content without toolUse + {"toolUse": {"name": "calculator"}}, # Incomplete tool use + ], + } + + result = recover_message_on_max_tokens_reached(message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First two content blocks should be preserved + assert result["content"][0] == {"text": "Regular text content."} + assert result["content"][1] == {"someOtherKey": "someValue"} + + # Third content block should be replaced with error message + assert "text" in result["content"][2] + assert "calculator" in result["content"][2]["text"] + assert "incomplete due to maximum token limits" in result["content"][2]["text"] diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index d9c2817b3..bf5668349 100644 --- 
a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -1,20 +1,48 @@ +import logging + import pytest +from src.strands.agent import AgentResult from strands import Agent, tool from strands.models.bedrock import BedrockModel from strands.types.exceptions import MaxTokensReachedException +logger = logging.getLogger(__name__) + @tool def story_tool(story: str) -> str: + """ + Tool that writes a story that is minimum 50,000 lines long. + """ return story -def test_context_window_overflow(): +def test_max_tokens_reached(): + """Test that MaxTokensReachedException is raised but the agent can still rerun on the second pass""" model = BedrockModel(max_tokens=100) agent = Agent(model=model, tools=[story_tool]) + # This should raise an exception with pytest.raises(MaxTokensReachedException): agent("Tell me a story!") - assert len(agent.messages) == 1 + # Validate that at least one message contains the incomplete tool use error message + expected_text = "tool use was incomplete due to maximum token limits being reached" + all_text_content = [ + content_block["text"] + for message in agent.messages + for content_block in message.get("content", []) + if "text" in content_block + ] + + assert any(expected_text in text for text in all_text_content), ( + f"Expected to find message containing '{expected_text}' in agent messages" + ) + + # Remove tools from agent and re-run with a generic question + agent.tool_registry.registry = {} + agent.tool_registry.tool_config = {} + + result: AgentResult = agent("What is 3+3") + assert result.stop_reason == "end_turn" From adac26f15930fe2fc6754f5f9ddeab2ff9698463 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 8 Aug 2025 10:44:39 -0400 Subject: [PATCH 026/104] fix(structured_output): do not modify conversation_history when prompt is passed (#628) --- src/strands/agent/agent.py | 20 +++++----- tests/strands/agent/test_agent.py | 52 +++++++++++++++++++++++++ tests/strands/agent/test_agent_hooks.py | 10 ++--- 3 files changed, 67 insertions(+), 15 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 111509e3a..2022142c6 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -403,8 +403,8 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T: """This method allows you to get structured output from the agent. - If you pass in a prompt, it will be added to the conversation history and the agent will respond to it. - If you don't pass in a prompt, it will use only the conversation history to respond. + If you pass in a prompt, it will be used temporarily without adding it to the conversation history. + If you don't pass in a prompt, it will use only the existing conversation history to respond. For smaller models, you may want to use the optional prompt to add additional instructions to explicitly instruct the model to output the structured data. @@ -412,7 +412,7 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l Args: output_model: The output model (a JSON schema written as a Pydantic BaseModel) that the agent will use when responding. - prompt: The prompt to use for the agent. + prompt: The prompt to use for the agent (will not be added to conversation history). Raises: ValueError: If no conversation history or prompt is provided. 
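        A minimal usage sketch of this behavior (assumes a default model provider is configured; ``Person`` is a placeholder schema) — the prompt is only used for the single structured-output call and is not appended to ``agent.messages``:

        ```
        from pydantic import BaseModel

        from strands import Agent


        class Person(BaseModel):
            name: str
            age: int


        agent = Agent()
        history_len = len(agent.messages)

        person = agent.structured_output(Person, "Jane Doe is 30 years old")

        # The temporary prompt is not persisted in the conversation history.
        assert len(agent.messages) == history_len
        ```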
@@ -430,8 +430,8 @@ async def structured_output_async( ) -> T: """This method allows you to get structured output from the agent. - If you pass in a prompt, it will be added to the conversation history and the agent will respond to it. - If you don't pass in a prompt, it will use only the conversation history to respond. + If you pass in a prompt, it will be used temporarily without adding it to the conversation history. + If you don't pass in a prompt, it will use only the existing conversation history to respond. For smaller models, you may want to use the optional prompt to add additional instructions to explicitly instruct the model to output the structured data. @@ -439,7 +439,7 @@ async def structured_output_async( Args: output_model: The output model (a JSON schema written as a Pydantic BaseModel) that the agent will use when responding. - prompt: The prompt to use for the agent. + prompt: The prompt to use for the agent (will not be added to conversation history). Raises: ValueError: If no conversation history or prompt is provided. @@ -450,12 +450,14 @@ async def structured_output_async( if not self.messages and not prompt: raise ValueError("No conversation history or prompt provided") - # add the prompt as the last message + # Create temporary messages array if prompt is provided if prompt: content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt - self._append_message({"role": "user", "content": content}) + temp_messages = self.messages + [{"role": "user", "content": content}] + else: + temp_messages = self.messages - events = self.model.structured_output(output_model, self.messages, system_prompt=self.system_prompt) + events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) async for event in events: if "callback" in event: self.callback_handler(**cast(dict, event["callback"])) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 4e310dace..c27243dfe 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -984,10 +984,17 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator): prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + # Store initial message count + initial_message_count = len(agent.messages) + tru_result = agent.structured_output(type(user), prompt) exp_result = user assert tru_result == exp_result + # Verify conversation history is not polluted + assert len(agent.messages) == initial_message_count + + # Verify the model was called with temporary messages array agent.model.structured_output.assert_called_once_with( type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt ) @@ -1008,10 +1015,17 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a }, ] + # Store initial message count + initial_message_count = len(agent.messages) + tru_result = agent.structured_output(type(user), prompt) exp_result = user assert tru_result == exp_result + # Verify conversation history is not polluted + assert len(agent.messages) == initial_message_count + + # Verify the model was called with temporary messages array agent.model.structured_output.assert_called_once_with( type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt ) @@ -1023,10 +1037,41 @@ async def test_agent_structured_output_in_async_context(agent, user, agenerator) prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + # Store initial message 
count + initial_message_count = len(agent.messages) + tru_result = await agent.structured_output_async(type(user), prompt) exp_result = user assert tru_result == exp_result + # Verify conversation history is not polluted + assert len(agent.messages) == initial_message_count + + +def test_agent_structured_output_without_prompt(agent, system_prompt, user, agenerator): + """Test that structured_output works with existing conversation history and no new prompt.""" + agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) + + # Add some existing messages to the agent + existing_messages = [ + {"role": "user", "content": [{"text": "Jane Doe is 30 years old"}]}, + {"role": "assistant", "content": [{"text": "I understand."}]}, + ] + agent.messages.extend(existing_messages) + + initial_message_count = len(agent.messages) + + tru_result = agent.structured_output(type(user)) # No prompt provided + exp_result = user + assert tru_result == exp_result + + # Verify conversation history is unchanged + assert len(agent.messages) == initial_message_count + assert agent.messages == existing_messages + + # Verify the model was called with existing messages only + agent.model.structured_output.assert_called_once_with(type(user), existing_messages, system_prompt=system_prompt) + @pytest.mark.asyncio async def test_agent_structured_output_async(agent, system_prompt, user, agenerator): @@ -1034,10 +1079,17 @@ async def test_agent_structured_output_async(agent, system_prompt, user, agenera prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + # Store initial message count + initial_message_count = len(agent.messages) + tru_result = agent.structured_output(type(user), prompt) exp_result = user assert tru_result == exp_result + # Verify conversation history is not polluted + assert len(agent.messages) == initial_message_count + + # Verify the model was called with temporary messages array agent.model.structured_output.assert_called_once_with( type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt ) diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index cd89fbc7a..9ab008ca2 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -267,13 +267,12 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): length, events = hook_provider.get_events() - assert length == 3 + assert length == 2 assert next(events) == BeforeInvocationEvent(agent=agent) - assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0]) assert next(events) == AfterInvocationEvent(agent=agent) - assert len(agent.messages) == 1 + assert len(agent.messages) == 0 # no new messages added @pytest.mark.asyncio @@ -285,10 +284,9 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a length, events = hook_provider.get_events() - assert length == 3 + assert length == 2 assert next(events) == BeforeInvocationEvent(agent=agent) - assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0]) assert next(events) == AfterInvocationEvent(agent=agent) - assert len(agent.messages) == 1 + assert len(agent.messages) == 0 # no new messages added From 99963b64c261431c6f10c31853a2dcc667a9ebbb Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 11 Aug 2025 17:11:27 +0200 Subject: [PATCH 027/104] feature(graph): Allow cyclic graphs (#497) --- .gitignore | 3 +- src/strands/multiagent/graph.py | 304 
++++++++++--- tests/strands/multiagent/test_graph.py | 579 ++++++++++++++++++++++++- 3 files changed, 805 insertions(+), 81 deletions(-) diff --git a/.gitignore b/.gitignore index cb34b9150..c27d1d902 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ __pycache__* *.bak .vscode dist -repl_state \ No newline at end of file +repl_state +.kiro \ No newline at end of file diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index cbba0fecf..9aee260b1 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -1,31 +1,33 @@ -"""Directed Acyclic Graph (DAG) Multi-Agent Pattern Implementation. +"""Directed Graph Multi-Agent Pattern Implementation. -This module provides a deterministic DAG-based agent orchestration system where +This module provides a deterministic graph-based agent orchestration system where agents or MultiAgentBase instances (like Swarm or Graph) are nodes in a graph, executed according to edge dependencies, with output from one node passed as input to connected nodes. Key Features: - Agents and MultiAgentBase instances (Swarm, Graph, etc.) as graph nodes -- Deterministic execution order based on DAG structure +- Deterministic execution based on dependency resolution - Output propagation along edges -- Topological sort for execution ordering +- Support for cyclic graphs (feedback loops) - Clear dependency management - Supports nested graphs (Graph as a node in another Graph) """ import asyncio +import copy import logging import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any, Callable, Tuple +from typing import Any, Callable, Optional, Tuple from opentelemetry import trace as trace_api from ..agent import Agent +from ..agent.state import AgentState from ..telemetry import get_tracer -from ..types.content import ContentBlock +from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status @@ -54,6 +56,7 @@ class GraphState: completed_nodes: set["GraphNode"] = field(default_factory=set) failed_nodes: set["GraphNode"] = field(default_factory=set) execution_order: list["GraphNode"] = field(default_factory=list) + start_time: float = field(default_factory=time.time) # Results results: dict[str, NodeResult] = field(default_factory=dict) @@ -69,6 +72,27 @@ class GraphState: edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list) entry_points: list["GraphNode"] = field(default_factory=list) + def should_continue( + self, + max_node_executions: Optional[int], + execution_timeout: Optional[float], + ) -> Tuple[bool, str]: + """Check if the graph should continue execution. 
+ + Returns: (should_continue, reason) + """ + # Check node execution limit (only if set) + if max_node_executions is not None and len(self.execution_order) >= max_node_executions: + return False, f"Max node executions reached: {max_node_executions}" + + # Check timeout (only if set) + if execution_timeout is not None: + elapsed = time.time() - self.start_time + if elapsed > execution_timeout: + return False, f"Execution timed out: {execution_timeout}s" + + return True, "Continuing" + @dataclass class GraphResult(MultiAgentResult): @@ -117,6 +141,33 @@ class GraphNode: execution_status: Status = Status.PENDING result: NodeResult | None = None execution_time: int = 0 + _initial_messages: Messages = field(default_factory=list, init=False) + _initial_state: AgentState = field(default_factory=AgentState, init=False) + + def __post_init__(self) -> None: + """Capture initial executor state after initialization.""" + # Deep copy the initial messages and state to preserve them + if hasattr(self.executor, "messages"): + self._initial_messages = copy.deepcopy(self.executor.messages) + + if hasattr(self.executor, "state") and hasattr(self.executor.state, "get"): + self._initial_state = AgentState(self.executor.state.get()) + + def reset_executor_state(self) -> None: + """Reset GraphNode executor state to initial state when graph was created. + + This is useful when nodes are executed multiple times and need to start + fresh on each execution, providing stateless behavior. + """ + if hasattr(self.executor, "messages"): + self.executor.messages = copy.deepcopy(self._initial_messages) + + if hasattr(self.executor, "state"): + self.executor.state = AgentState(self._initial_state.get()) + + # Reset execution status + self.execution_status = Status.PENDING + self.result = None def __hash__(self) -> int: """Return hash for GraphNode based on node_id.""" @@ -164,6 +215,12 @@ def __init__(self) -> None: self.edges: set[GraphEdge] = set() self.entry_points: set[GraphNode] = set() + # Configuration options + self._max_node_executions: Optional[int] = None + self._execution_timeout: Optional[float] = None + self._node_timeout: Optional[float] = None + self._reset_on_revisit: bool = False + def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode: """Add an Agent or MultiAgentBase instance as a node to the graph.""" _validate_node_executor(executor, self.nodes) @@ -213,8 +270,48 @@ def set_entry_point(self, node_id: str) -> "GraphBuilder": self.entry_points.add(self.nodes[node_id]) return self + def reset_on_revisit(self, enabled: bool = True) -> "GraphBuilder": + """Control whether nodes reset their state when revisited. + + When enabled, nodes will reset their messages and state to initial values + each time they are revisited (re-executed). This is useful for stateless + behavior where nodes should start fresh on each revisit. + + Args: + enabled: Whether to reset node state when revisited (default: True) + """ + self._reset_on_revisit = enabled + return self + + def set_max_node_executions(self, max_executions: int) -> "GraphBuilder": + """Set maximum number of node executions allowed. + + Args: + max_executions: Maximum total node executions (None for no limit) + """ + self._max_node_executions = max_executions + return self + + def set_execution_timeout(self, timeout: float) -> "GraphBuilder": + """Set total execution timeout. 
+ + Args: + timeout: Total execution timeout in seconds (None for no limit) + """ + self._execution_timeout = timeout + return self + + def set_node_timeout(self, timeout: float) -> "GraphBuilder": + """Set individual node execution timeout. + + Args: + timeout: Individual node timeout in seconds (None for no limit) + """ + self._node_timeout = timeout + return self + def build(self) -> "Graph": - """Build and validate the graph.""" + """Build and validate the graph with configured settings.""" if not self.nodes: raise ValueError("Graph must contain at least one node") @@ -230,44 +327,53 @@ def build(self) -> "Graph": # Validate entry points and check for cycles self._validate_graph() - return Graph(nodes=self.nodes.copy(), edges=self.edges.copy(), entry_points=self.entry_points.copy()) + return Graph( + nodes=self.nodes.copy(), + edges=self.edges.copy(), + entry_points=self.entry_points.copy(), + max_node_executions=self._max_node_executions, + execution_timeout=self._execution_timeout, + node_timeout=self._node_timeout, + reset_on_revisit=self._reset_on_revisit, + ) def _validate_graph(self) -> None: - """Validate graph structure and detect cycles.""" + """Validate graph structure.""" # Validate entry points exist entry_point_ids = {node.node_id for node in self.entry_points} invalid_entries = entry_point_ids - set(self.nodes.keys()) if invalid_entries: raise ValueError(f"Entry points not found in nodes: {invalid_entries}") - # Check for cycles using DFS with color coding - WHITE, GRAY, BLACK = 0, 1, 2 - colors = {node_id: WHITE for node_id in self.nodes} - - def has_cycle_from(node_id: str) -> bool: - if colors[node_id] == GRAY: - return True # Back edge found - cycle detected - if colors[node_id] == BLACK: - return False - - colors[node_id] = GRAY - # Check all outgoing edges for cycles - for edge in self.edges: - if edge.from_node.node_id == node_id and has_cycle_from(edge.to_node.node_id): - return True - colors[node_id] = BLACK - return False - - # Check for cycles from each unvisited node - if any(colors[node_id] == WHITE and has_cycle_from(node_id) for node_id in self.nodes): - raise ValueError("Graph contains cycles - must be a directed acyclic graph") + # Warn about potential infinite loops if no execution limits are set + if self._max_node_executions is None and self._execution_timeout is None: + logger.warning("Graph without execution limits may run indefinitely if cycles exist") class Graph(MultiAgentBase): - """Directed Acyclic Graph multi-agent orchestration.""" + """Directed Graph multi-agent orchestration with configurable revisit behavior.""" - def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_points: set[GraphNode]) -> None: - """Initialize Graph.""" + def __init__( + self, + nodes: dict[str, GraphNode], + edges: set[GraphEdge], + entry_points: set[GraphNode], + max_node_executions: Optional[int] = None, + execution_timeout: Optional[float] = None, + node_timeout: Optional[float] = None, + reset_on_revisit: bool = False, + ) -> None: + """Initialize Graph with execution limits and reset behavior. 
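+        A rough sketch of wiring a cyclic graph with these limits (the agents and node ids are placeholders; the ``GraphBuilder`` import path is assumed from this module's location):
+
+        ```
+        import asyncio
+
+        from strands import Agent
+        from strands.multiagent.graph import GraphBuilder
+
+        writer = Agent()    # placeholder agent
+        reviewer = Agent()  # placeholder agent
+
+        builder = GraphBuilder()
+        builder.add_node(writer, "write")
+        builder.add_node(reviewer, "review")
+        builder.add_edge("write", "review")
+        builder.add_edge("review", "write")  # feedback loop creates a cycle
+        builder.set_entry_point("write")
+        builder.reset_on_revisit()           # nodes start fresh on each revisit
+        builder.set_max_node_executions(10)  # guard against unbounded cycling
+        builder.set_execution_timeout(300.0)
+        builder.set_node_timeout(60.0)
+
+        graph = builder.build()
+        result = asyncio.run(graph.invoke_async("Draft a short summary and refine it once"))
+        ```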
+ + Args: + nodes: Dictionary of node_id to GraphNode + edges: Set of GraphEdge objects + entry_points: Set of GraphNode objects that are entry points + max_node_executions: Maximum total node executions (default: None - no limit) + execution_timeout: Total execution timeout in seconds (default: None - no limit) + node_timeout: Individual node timeout in seconds (default: None - no limit) + reset_on_revisit: Whether to reset node state when revisited (default: False) + """ super().__init__() # Validate nodes for duplicate instances @@ -276,6 +382,10 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi self.nodes = nodes self.edges = edges self.entry_points = entry_points + self.max_node_executions = max_node_executions + self.execution_timeout = execution_timeout + self.node_timeout = node_timeout + self.reset_on_revisit = reset_on_revisit self.state = GraphState() self.tracer = get_tracer() @@ -294,20 +404,34 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G logger.debug("task=<%s> | starting graph execution", task) # Initialize state + start_time = time.time() self.state = GraphState( status=Status.EXECUTING, task=task, total_nodes=len(self.nodes), edges=[(edge.from_node, edge.to_node) for edge in self.edges], entry_points=list(self.entry_points), + start_time=start_time, ) - start_time = time.time() span = self.tracer.start_multiagent_span(task, "graph") with trace_api.use_span(span, end_on_exit=True): try: + logger.debug( + "max_node_executions=<%s>, execution_timeout=<%s>s, node_timeout=<%s>s | graph execution config", + self.max_node_executions or "None", + self.execution_timeout or "None", + self.node_timeout or "None", + ) + await self._execute_graph() - self.state.status = Status.COMPLETED + + # Set final status based on execution results + if self.state.failed_nodes: + self.state.status = Status.FAILED + elif self.state.status == Status.EXECUTING: # Only set to COMPLETED if still executing and no failures + self.state.status = Status.COMPLETED + logger.debug("status=<%s> | graph execution completed", self.state.status) except Exception: @@ -335,6 +459,16 @@ async def _execute_graph(self) -> None: ready_nodes = list(self.entry_points) while ready_nodes: + # Check execution limits before continuing + should_continue, reason = self.state.should_continue( + max_node_executions=self.max_node_executions, + execution_timeout=self.execution_timeout, + ) + if not should_continue: + self.state.status = Status.FAILED + logger.debug("reason=<%s> | stopping execution", reason) + return # Let the top-level exception handler deal with it + current_batch = ready_nodes.copy() ready_nodes.clear() @@ -386,7 +520,14 @@ def _is_node_ready_with_conditions(self, node: GraphNode) -> bool: return False async def _execute_node(self, node: GraphNode) -> None: - """Execute a single node with error handling.""" + """Execute a single node with error handling and timeout protection.""" + # Reset the node's state if reset_on_revisit is enabled and it's being revisited + if self.reset_on_revisit and node in self.state.completed_nodes: + logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) + node.reset_executor_state() + # Remove from completed nodes since we're re-executing it + self.state.completed_nodes.remove(node) + node.execution_status = Status.EXECUTING logger.debug("node_id=<%s> | executing node", node.node_id) @@ -395,42 +536,65 @@ async def _execute_node(self, node: GraphNode) -> None: # Build node input from 
satisfied dependencies node_input = self._build_node_input(node) - # Execute based on node type and create unified NodeResult - if isinstance(node.executor, MultiAgentBase): - multi_agent_result = await node.executor.invoke_async(node_input) - - # Create NodeResult with MultiAgentResult directly - node_result = NodeResult( - result=multi_agent_result, # type is MultiAgentResult - execution_time=multi_agent_result.execution_time, - status=Status.COMPLETED, - accumulated_usage=multi_agent_result.accumulated_usage, - accumulated_metrics=multi_agent_result.accumulated_metrics, - execution_count=multi_agent_result.execution_count, - ) + # Execute with timeout protection (only if node_timeout is set) + try: + # Execute based on node type and create unified NodeResult + if isinstance(node.executor, MultiAgentBase): + if self.node_timeout is not None: + multi_agent_result = await asyncio.wait_for( + node.executor.invoke_async(node_input), + timeout=self.node_timeout, + ) + else: + multi_agent_result = await node.executor.invoke_async(node_input) + + # Create NodeResult with MultiAgentResult directly + node_result = NodeResult( + result=multi_agent_result, # type is MultiAgentResult + execution_time=multi_agent_result.execution_time, + status=Status.COMPLETED, + accumulated_usage=multi_agent_result.accumulated_usage, + accumulated_metrics=multi_agent_result.accumulated_metrics, + execution_count=multi_agent_result.execution_count, + ) - elif isinstance(node.executor, Agent): - agent_response = await node.executor.invoke_async(node_input) - - # Extract metrics from agent response - usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) - metrics = Metrics(latencyMs=0) - if hasattr(agent_response, "metrics") and agent_response.metrics: - if hasattr(agent_response.metrics, "accumulated_usage"): - usage = agent_response.metrics.accumulated_usage - if hasattr(agent_response.metrics, "accumulated_metrics"): - metrics = agent_response.metrics.accumulated_metrics - - node_result = NodeResult( - result=agent_response, # type is AgentResult - execution_time=round((time.time() - start_time) * 1000), - status=Status.COMPLETED, - accumulated_usage=usage, - accumulated_metrics=metrics, - execution_count=1, + elif isinstance(node.executor, Agent): + if self.node_timeout is not None: + agent_response = await asyncio.wait_for( + node.executor.invoke_async(node_input), + timeout=self.node_timeout, + ) + else: + agent_response = await node.executor.invoke_async(node_input) + + # Extract metrics from agent response + usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) + metrics = Metrics(latencyMs=0) + if hasattr(agent_response, "metrics") and agent_response.metrics: + if hasattr(agent_response.metrics, "accumulated_usage"): + usage = agent_response.metrics.accumulated_usage + if hasattr(agent_response.metrics, "accumulated_metrics"): + metrics = agent_response.metrics.accumulated_metrics + + node_result = NodeResult( + result=agent_response, # type is AgentResult + execution_time=round((time.time() - start_time) * 1000), + status=Status.COMPLETED, + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=1, + ) + else: + raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") + + except asyncio.TimeoutError: + timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s" + logger.exception( + "node=<%s>, timeout=<%s>s | node execution timed out after timeout", + node.node_id, + self.node_timeout, ) - else: - raise 
ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") + raise Exception(timeout_msg) from None # Mark as completed node.execution_status = Status.COMPLETED diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index cb74f515c..c60361da8 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1,8 +1,11 @@ +import asyncio +import time from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from strands.agent import Agent, AgentResult +from strands.agent.state import AgentState from strands.hooks import AgentInitializedEvent from strands.hooks.registry import HookProvider, HookRegistry from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult @@ -251,7 +254,8 @@ class UnsupportedExecutor: builder.add_node(UnsupportedExecutor(), "unsupported_node") graph = builder.build() - with pytest.raises(ValueError, match="Node 'unsupported_node' of type.*is not supported"): + # Execute the graph - should raise ValueError due to unsupported node type + with pytest.raises(ValueError, match="Node 'unsupported_node' of type .* is not supported"): await graph.invoke_async("test task") mock_strands_tracer.start_multiagent_span.assert_called() @@ -285,12 +289,10 @@ async def mock_invoke_failure(*args, **kwargs): graph = builder.build() + # Execute the graph - should raise Exception due to failing agent with pytest.raises(Exception, match="Simulated failure"): await graph.invoke_async("Test error handling") - assert graph.state.status == Status.FAILED - assert any(node.node_id == "fail_node" for node in graph.state.failed_nodes) - assert len(graph.state.completed_nodes) == 0 mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called_once() @@ -314,6 +316,91 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span): mock_use_span.assert_called_once() +@pytest.mark.asyncio +async def test_cyclic_graph_execution(mock_strands_tracer, mock_use_span): + """Test execution of a graph with cycles.""" + # Create mock agents with state tracking + agent_a = create_mock_agent("agent_a", "Agent A response") + agent_b = create_mock_agent("agent_b", "Agent B response") + agent_c = create_mock_agent("agent_c", "Agent C response") + + # Add state to agents to track execution + agent_a.state = AgentState() + agent_b.state = AgentState() + agent_c.state = AgentState() + + # Create a spy to track reset calls + reset_spy = MagicMock() + + # Create a graph with a cycle: A -> B -> C -> A + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_node(agent_c, "c") + builder.add_edge("a", "b") + builder.add_edge("b", "c") + builder.add_edge("c", "a") # Creates cycle + builder.set_entry_point("a") + builder.reset_on_revisit() # Enable state reset on revisit + + # Patch the reset_executor_state method to track calls + original_reset = GraphNode.reset_executor_state + + def spy_reset(self): + reset_spy(self.node_id) + original_reset(self) + + with patch.object(GraphNode, "reset_executor_state", spy_reset): + graph = builder.build() + + # Set a maximum iteration limit to prevent infinite loops + # but ensure we go through the cycle at least twice + # This value is used in the LimitedGraph class below + + # Execute the graph with a task that will cause it to cycle + result = await graph.invoke_async("Test cyclic graph execution") + + # Verify that the graph executed successfully + assert result.status 
== Status.COMPLETED + + # Verify that each agent was called at least once + agent_a.invoke_async.assert_called() + agent_b.invoke_async.assert_called() + agent_c.invoke_async.assert_called() + + # Verify that the execution order includes all nodes + assert len(result.execution_order) >= 3 + assert any(node.node_id == "a" for node in result.execution_order) + assert any(node.node_id == "b" for node in result.execution_order) + assert any(node.node_id == "c" for node in result.execution_order) + + # Verify that node state was reset during cyclic execution + # If we have more than 3 nodes in execution_order, at least one node was revisited + if len(result.execution_order) > 3: + # Check that reset_executor_state was called for revisited nodes + reset_spy.assert_called() + + # Count occurrences of each node in execution order + node_counts = {} + for node in result.execution_order: + node_counts[node.node_id] = node_counts.get(node.node_id, 0) + 1 + + # At least one node should appear multiple times + assert any(count > 1 for count in node_counts.values()), "No node was revisited in the cycle" + + # For each node that appears multiple times, verify reset was called + for node_id, count in node_counts.items(): + if count > 1: + # Check that reset was called at least (count-1) times for this node + reset_calls = sum(1 for call in reset_spy.call_args_list if call[0][0] == node_id) + assert reset_calls >= count - 1, ( + f"Node {node_id} appeared {count} times but reset was called {reset_calls} times" + ) + + # Verify all nodes were completed + assert result.completed_nodes == 3 + + def test_graph_builder_validation(): """Test GraphBuilder validation and error handling.""" # Test empty graph validation @@ -343,7 +430,11 @@ def test_graph_builder_validation(): node2 = GraphNode("node2", duplicate_agent) # Same agent instance nodes = {"node1": node1, "node2": node2} with pytest.raises(ValueError, match="Duplicate node instance detected"): - Graph(nodes=nodes, edges=set(), entry_points=set()) + Graph( + nodes=nodes, + edges=set(), + entry_points=set(), + ) # Test edge validation with non-existent nodes builder = GraphBuilder() @@ -368,7 +459,7 @@ def test_graph_builder_validation(): with pytest.raises(ValueError, match="Entry points not found in nodes"): builder.build() - # Test cycle detection + # Test cycle detection (should be forbidden by default) builder = GraphBuilder() builder.add_node(agent1, "a") builder.add_node(agent2, "b") @@ -378,8 +469,9 @@ def test_graph_builder_validation(): builder.add_edge("c", "a") # Creates cycle builder.set_entry_point("a") - with pytest.raises(ValueError, match="Graph contains cycles"): - builder.build() + # Should succeed - cycles are now allowed by default + graph = builder.build() + assert any(node.node_id == "a" for node in graph.entry_points) # Test auto-detection of entry points builder = GraphBuilder() @@ -400,6 +492,259 @@ def test_graph_builder_validation(): with pytest.raises(ValueError, match="No entry points found - all nodes have dependencies"): builder.build() + # Test custom execution limits and reset_on_revisit + builder = GraphBuilder() + builder.add_node(agent1, "test_node") + graph = ( + builder.set_max_node_executions(10) + .set_execution_timeout(300.0) + .set_node_timeout(60.0) + .reset_on_revisit() + .build() + ) + assert graph.max_node_executions == 10 + assert graph.execution_timeout == 300.0 + assert graph.node_timeout == 60.0 + assert graph.reset_on_revisit is True + + # Test default execution limits and reset_on_revisit (None and False) 
+ builder = GraphBuilder() + builder.add_node(agent1, "test_node") + graph = builder.build() + assert graph.max_node_executions is None + assert graph.execution_timeout is None + assert graph.node_timeout is None + assert graph.reset_on_revisit is False + + +@pytest.mark.asyncio +async def test_graph_execution_limits(mock_strands_tracer, mock_use_span): + """Test graph execution limits (max_node_executions and execution_timeout).""" + # Test with a simple linear graph first to verify limits work + agent_a = create_mock_agent("agent_a", "Response A") + agent_b = create_mock_agent("agent_b", "Response B") + agent_c = create_mock_agent("agent_c", "Response C") + + # Create a linear graph: a -> b -> c + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_node(agent_c, "c") + builder.add_edge("a", "b") + builder.add_edge("b", "c") + builder.set_entry_point("a") + + # Test with no limits (backward compatibility) - should complete normally + graph = builder.build() # No limits specified + result = await graph.invoke_async("Test execution") + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 3 # All 3 nodes should execute + + # Test with limit that allows completion + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_node(agent_c, "c") + builder.add_edge("a", "b") + builder.add_edge("b", "c") + builder.set_entry_point("a") + graph = builder.set_max_node_executions(5).set_execution_timeout(900.0).set_node_timeout(300.0).build() + result = await graph.invoke_async("Test execution") + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 3 # All 3 nodes should execute + + # Test with limit that prevents full completion + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_node(agent_c, "c") + builder.add_edge("a", "b") + builder.add_edge("b", "c") + builder.set_entry_point("a") + graph = builder.set_max_node_executions(2).set_execution_timeout(900.0).set_node_timeout(300.0).build() + result = await graph.invoke_async("Test execution limit") + assert result.status == Status.FAILED # Should fail due to limit + assert len(result.execution_order) == 2 # Should stop at 2 executions + + # Test execution timeout by manipulating start time (like Swarm does) + timeout_agent_a = create_mock_agent("timeout_agent_a", "Response A") + timeout_agent_b = create_mock_agent("timeout_agent_b", "Response B") + + # Create a cyclic graph that would run indefinitely + builder = GraphBuilder() + builder.add_node(timeout_agent_a, "a") + builder.add_node(timeout_agent_b, "b") + builder.add_edge("a", "b") + builder.add_edge("b", "a") # Creates cycle + builder.set_entry_point("a") + + # Enable reset_on_revisit so the cycle can continue + graph = builder.reset_on_revisit(True).set_execution_timeout(5.0).set_max_node_executions(100).build() + + # Manipulate the start time to simulate timeout (like Swarm does) + result = await graph.invoke_async("Test execution timeout") + # Manually set start time to simulate timeout condition + graph.state.start_time = time.time() - 10 # Set start time to 10 seconds ago + + # Check the timeout logic directly + should_continue, reason = graph.state.should_continue(max_node_executions=100, execution_timeout=5.0) + assert should_continue is False + assert "Execution timed out" in reason + + # builder = GraphBuilder() + # builder.add_node(slow_agent, "slow") + # graph = 
(builder.set_max_node_executions(1000) # High limit to avoid hitting this + # .set_execution_timeout(0.05) # Very short execution timeout + # .set_node_timeout(300.0) + # .build()) + + # result = await graph.invoke_async("Test timeout") + # assert result.status == Status.FAILED # Should fail due to timeout + + mock_strands_tracer.start_multiagent_span.assert_called() + mock_use_span.assert_called() + + +@pytest.mark.asyncio +async def test_graph_node_timeout(mock_strands_tracer, mock_use_span): + """Test individual node timeout functionality.""" + + # Create a mock agent that takes longer than the node timeout + timeout_agent = create_mock_agent("timeout_agent", "Should timeout") + + async def timeout_invoke(*args, **kwargs): + await asyncio.sleep(0.2) # Longer than node timeout + return timeout_agent.return_value + + timeout_agent.invoke_async = AsyncMock(side_effect=timeout_invoke) + + builder = GraphBuilder() + builder.add_node(timeout_agent, "timeout_node") + + # Test with no timeout (backward compatibility) - should complete normally + graph = builder.build() # No timeout specified + result = await graph.invoke_async("Test no timeout") + assert result.status == Status.COMPLETED + assert result.completed_nodes == 1 + + # Test with very short node timeout - should raise timeout exception + builder = GraphBuilder() + builder.add_node(timeout_agent, "timeout_node") + graph = builder.set_max_node_executions(50).set_execution_timeout(900.0).set_node_timeout(0.1).build() + + # Execute the graph - should raise Exception due to timeout + with pytest.raises(Exception, match="Node 'timeout_node' execution timed out after 0.1s"): + await graph.invoke_async("Test node timeout") + + mock_strands_tracer.start_multiagent_span.assert_called() + mock_use_span.assert_called() + + +@pytest.mark.asyncio +async def test_backward_compatibility_no_limits(): + """Test that graphs with no limits specified work exactly as before.""" + # Create simple agents + agent_a = create_mock_agent("agent_a", "Response A") + agent_b = create_mock_agent("agent_b", "Response B") + + # Create a simple linear graph + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_edge("a", "b") + builder.set_entry_point("a") + + # Build without specifying any limits - should work exactly as before + graph = builder.build() + + # Verify the limits are None (no limits) + assert graph.max_node_executions is None + assert graph.execution_timeout is None + assert graph.node_timeout is None + + # Execute the graph - should complete normally + result = await graph.invoke_async("Test backward compatibility") + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 2 # Both nodes should execute + + +@pytest.mark.asyncio +async def test_node_reset_executor_state(): + """Test that GraphNode.reset_executor_state properly resets node state.""" + # Create a mock agent with state + agent = create_mock_agent("test_agent", "Test response") + agent.state = AgentState() + agent.state.set("test_key", "test_value") + agent.messages = [{"role": "system", "content": "Initial system message"}] + + # Create a GraphNode with this agent + node = GraphNode("test_node", agent) + + # Verify initial state is captured during initialization + assert len(node._initial_messages) == 1 + assert node._initial_messages[0]["role"] == "system" + assert node._initial_messages[0]["content"] == "Initial system message" + + # Modify agent state and messages after initialization + 
agent.state.set("new_key", "new_value") + agent.messages.append({"role": "user", "content": "New message"}) + + # Also modify execution status and result + node.execution_status = Status.COMPLETED + node.result = NodeResult( + result="test result", + execution_time=100, + status=Status.COMPLETED, + accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + accumulated_metrics={"latencyMs": 100}, + execution_count=1, + ) + + # Verify state was modified + assert len(agent.messages) == 2 + assert agent.state.get("new_key") == "new_value" + assert node.execution_status == Status.COMPLETED + assert node.result is not None + + # Reset the executor state + node.reset_executor_state() + + # Verify messages were reset to initial values + assert len(agent.messages) == 1 + assert agent.messages[0]["role"] == "system" + assert agent.messages[0]["content"] == "Initial system message" + + # Verify agent state was reset + # The test_key should be gone since it wasn't in the initial state + assert agent.state.get("new_key") is None + + # Verify execution status is reset + assert node.execution_status == Status.PENDING + assert node.result is None + + # Test with MultiAgentBase executor + multi_agent = create_mock_multi_agent("multi_agent") + multi_agent_node = GraphNode("multi_node", multi_agent) + + # Since MultiAgentBase doesn't have messages or state attributes, + # reset_executor_state should not fail + multi_agent_node.execution_status = Status.COMPLETED + multi_agent_node.result = NodeResult( + result="test result", + execution_time=100, + status=Status.COMPLETED, + accumulated_usage={}, + accumulated_metrics={}, + execution_count=1, + ) + + # Reset should work without errors + multi_agent_node.reset_executor_state() + + # Verify execution status is reset + assert multi_agent_node.execution_status == Status.PENDING + assert multi_agent_node.result is None + def test_graph_dataclasses_and_enums(): """Test dataclass initialization, properties, and enum behavior.""" @@ -417,6 +762,7 @@ def test_graph_dataclasses_and_enums(): assert state.task == "" assert state.accumulated_usage == {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} assert state.execution_count == 0 + assert state.start_time > 0 # Should be set by default factory # Test GraphState with custom values state = GraphState(status=Status.EXECUTING, task="custom task", total_nodes=5, execution_count=3) @@ -540,9 +886,222 @@ def register_hooks(self, registry, **kwargs): # Test with session manager in Graph constructor node_with_session = GraphNode("node_with_session", agent_with_session) with pytest.raises(ValueError, match="Session persistence is not supported for Graph agents yet"): - Graph(nodes={"node_with_session": node_with_session}, edges=set(), entry_points=set()) + Graph( + nodes={"node_with_session": node_with_session}, + edges=set(), + entry_points=set(), + ) # Test with callbacks in Graph constructor node_with_hooks = GraphNode("node_with_hooks", agent_with_hooks) with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"): - Graph(nodes={"node_with_hooks": node_with_hooks}, edges=set(), entry_points=set()) + Graph( + nodes={"node_with_hooks": node_with_hooks}, + edges=set(), + entry_points=set(), + ) + + +@pytest.mark.asyncio +async def test_controlled_cyclic_execution(): + """Test cyclic graph execution with controlled cycle count to verify state reset.""" + + # Create a stateful agent that tracks its own execution count + class StatefulAgent(Agent): + def __init__(self, 
name): + super().__init__() + self.name = name + self.state = AgentState() + self.state.set("execution_count", 0) + self.messages = [] + self._session_manager = None + self.hooks = HookRegistry() + + async def invoke_async(self, input_data): + # Increment execution count in state + count = self.state.get("execution_count") or 0 + self.state.set("execution_count", count + 1) + + return AgentResult( + message={"role": "assistant", "content": [{"text": f"{self.name} response (execution {count + 1})"}]}, + stop_reason="end_turn", + state={}, + metrics=Mock( + accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + accumulated_metrics={"latencyMs": 100.0}, + ), + ) + + # Create agents + agent_a = StatefulAgent("agent_a") + agent_b = StatefulAgent("agent_b") + + # Create a graph with a simple cycle: A -> B -> A + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_edge("a", "b") + builder.add_edge("b", "a") # Creates cycle + builder.set_entry_point("a") + builder.reset_on_revisit() # Enable state reset on revisit + + # Build with limited max_node_executions to prevent infinite loop + graph = builder.set_max_node_executions(3).build() + + # Execute the graph + result = await graph.invoke_async("Test controlled cyclic execution") + + # With a 2-node cycle and limit of 3, we should see either completion or failure + # The exact behavior depends on how the cycle detection works + if result.status == Status.COMPLETED: + # If it completed, verify it executed some nodes + assert len(result.execution_order) >= 2 + assert result.execution_order[0].node_id == "a" + elif result.status == Status.FAILED: + # If it failed due to limits, verify it hit the limit + assert len(result.execution_order) == 3 # Should stop at exactly 3 executions + assert result.execution_order[0].node_id == "a" + else: + # Should be either completed or failed + raise AssertionError(f"Unexpected status: {result.status}") + + # Most importantly, verify that state was reset properly between executions + # The state.execution_count should be set for both agents after execution + assert agent_a.state.get("execution_count") >= 1 # Node A executed at least once + assert agent_b.state.get("execution_count") >= 1 # Node B executed at least once + + +def test_reset_on_revisit_backward_compatibility(): + """Test that reset_on_revisit provides backward compatibility by default.""" + agent1 = create_mock_agent("agent1") + agent2 = create_mock_agent("agent2") + + # Test default behavior - reset_on_revisit is False by default + builder = GraphBuilder() + builder.add_node(agent1, "a") + builder.add_node(agent2, "b") + builder.add_edge("a", "b") + builder.set_entry_point("a") + + graph = builder.build() + assert graph.reset_on_revisit is False + + # Test reset_on_revisit with True + builder = GraphBuilder() + builder.add_node(agent1, "a") + builder.add_node(agent2, "b") + builder.add_edge("a", "b") + builder.set_entry_point("a") + builder.reset_on_revisit(True) + + graph = builder.build() + assert graph.reset_on_revisit is True + + # Test reset_on_revisit with False explicitly + builder = GraphBuilder() + builder.add_node(agent1, "a") + builder.add_node(agent2, "b") + builder.add_edge("a", "b") + builder.set_entry_point("a") + builder.reset_on_revisit(False) + + graph = builder.build() + assert graph.reset_on_revisit is False + + +def test_reset_on_revisit_method_chaining(): + """Test that reset_on_revisit method returns GraphBuilder for chaining.""" + agent1 = 
create_mock_agent("agent1") + + builder = GraphBuilder() + result = builder.reset_on_revisit() + + # Verify method chaining works + assert result is builder + assert builder._reset_on_revisit is True + + # Test full method chaining + builder.add_node(agent1, "test_node") + builder.set_max_node_executions(10) + graph = builder.build() + + assert graph.reset_on_revisit is True + assert graph.max_node_executions == 10 + + +@pytest.mark.asyncio +async def test_linear_graph_behavior(): + """Test that linear graph behavior works correctly.""" + agent_a = create_mock_agent("agent_a", "Response A") + agent_b = create_mock_agent("agent_b", "Response B") + + # Create linear graph + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_edge("a", "b") + builder.set_entry_point("a") + + graph = builder.build() + assert graph.reset_on_revisit is False + + # Execute should work normally + result = await graph.invoke_async("Test linear execution") + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 2 + assert result.execution_order[0].node_id == "a" + assert result.execution_order[1].node_id == "b" + + # Verify agents were called once each (no state reset) + agent_a.invoke_async.assert_called_once() + agent_b.invoke_async.assert_called_once() + + +@pytest.mark.asyncio +async def test_state_reset_only_with_cycles_enabled(): + """Test that state reset only happens when cycles are enabled.""" + # Create a mock agent that tracks state modifications + agent = create_mock_agent("test_agent", "Test response") + agent.state = AgentState() + agent.messages = [{"role": "system", "content": "Initial message"}] + + # Create GraphNode + node = GraphNode("test_node", agent) + + # Simulate agent being in completed_nodes (as if revisited) + from strands.multiagent.graph import GraphState + + state = GraphState() + state.completed_nodes.add(node) + + # Create graph with cycles disabled (default) + builder = GraphBuilder() + builder.add_node(agent, "test_node") + graph = builder.build() + + # Mock the _execute_node method to test conditional reset logic + import unittest.mock + + with unittest.mock.patch.object(node, "reset_executor_state") as mock_reset: + # Simulate the conditional logic from _execute_node + if graph.reset_on_revisit and node in state.completed_nodes: + node.reset_executor_state() + state.completed_nodes.remove(node) + + # With reset_on_revisit disabled, reset should not be called + mock_reset.assert_not_called() + + # Now test with reset_on_revisit enabled + builder = GraphBuilder() + builder.add_node(agent, "test_node") + builder.reset_on_revisit() + graph = builder.build() + + with unittest.mock.patch.object(node, "reset_executor_state") as mock_reset: + # Simulate the conditional logic from _execute_node + if graph.reset_on_revisit and node in state.completed_nodes: + node.reset_executor_state() + state.completed_nodes.remove(node) + + # With reset_on_revisit enabled, reset should be called + mock_reset.assert_called_once() From 72709cf16d40b985d05ecf2ddb2081fbe28d1aa2 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Mon, 11 Aug 2025 14:35:57 -0400 Subject: [PATCH 028/104] chore: request to include code snippet section (#654) --- .github/ISSUE_TEMPLATE/bug_report.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 3c357173c..b3898b7f7 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ 
b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -61,9 +61,10 @@ body: label: Steps to Reproduce description: Detailed steps to reproduce the behavior placeholder: | - 1. Install Strands using... - 2. Run the command... - 3. See error... + 1. Code Snippet (Minimal reproducible example) + 2. Install Strands using... + 3. Run the command... + 4. See error... validations: required: true - type: textarea From 8434409a1f85816c6ec42756c79eb05b0914d6d1 Mon Sep 17 00:00:00 2001 From: fhwilton55 <81768750+fhwilton55@users.noreply.github.com> Date: Tue, 12 Aug 2025 18:16:53 -0400 Subject: [PATCH 029/104] feat: Add configuration option to MCP Client for server init timeout (#657) Co-authored-by: Harry Wilton --- src/strands/tools/mcp/mcp_client.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index c1aa96df3..7cb03e46f 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -63,17 +63,23 @@ class MCPClient: from MCP tools, it will be returned as the last item in the content array of the ToolResult. """ - def __init__(self, transport_callable: Callable[[], MCPTransport]): + def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_timeout: int = 30): """Initialize a new MCP Server connection. Args: transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple + startup_timeout: Timeout after which MCP server initialization should be cancelled + Defaults to 30. """ + self._startup_timeout = startup_timeout + mcp_instrumentation() self._session_id = uuid.uuid4() self._log_debug_with_thread("initializing MCPClient connection") - self._init_future: futures.Future[None] = futures.Future() # Main thread blocks until future completes - self._close_event = asyncio.Event() # Do not want to block other threads while close event is false + # Main thread blocks until future completesock + self._init_future: futures.Future[None] = futures.Future() + # Do not want to block other threads while close event is false + self._close_event = asyncio.Event() self._transport_callable = transport_callable self._background_thread: threading.Thread | None = None @@ -109,7 +115,7 @@ def start(self) -> "MCPClient": self._log_debug_with_thread("background thread started, waiting for ready event") try: # Blocking main thread until session is initialized in other thread or if the thread stops - self._init_future.result(timeout=30) + self._init_future.result(timeout=self._startup_timeout) self._log_debug_with_thread("the client initialization was successful") except futures.TimeoutError as e: raise MCPClientInitializationError("background thread did not start in 30 seconds") from e @@ -347,7 +353,8 @@ async def _async_background_thread(self) -> None: self._log_debug_with_thread("session initialized successfully") # Store the session for use while we await the close event self._background_thread_session = session - self._init_future.set_result(None) # Signal that the session has been created and is ready for use + # Signal that the session has been created and is ready for use + self._init_future.set_result(None) self._log_debug_with_thread("waiting for close signal") # Keep background thread running until signaled to close. 
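
For reference, a minimal usage sketch of the configurable startup timeout introduced above; the stdio transport helpers, server command, and import paths are assumptions for illustration and are not part of this change.

from mcp import StdioServerParameters, stdio_client
from strands.tools.mcp import MCPClient

# Give a slow MCP server up to 60 seconds to initialize instead of the default 30.
slow_client = MCPClient(
    lambda: stdio_client(StdioServerParameters(command="python", args=["my_server.py"])),
    startup_timeout=60,
)

with slow_client:  # start() raises MCPClientInitializationError if initialization exceeds the timeout
    tools = slow_client.list_tools_sync()
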
From 49ff22678b27b737658d2b6215365c454bc19db6 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 13 Aug 2025 08:58:33 -0400 Subject: [PATCH 030/104] fix: Bedrock hang when exception occurs during message conversion (#643) Previously (#642) bedrock would hang during message conversion because the exception was not being caught and thus the queue was always empty. Now all exceptions during conversion are caught Co-authored-by: Mackenzie Zastrow --- pyproject.toml | 2 +- src/strands/models/bedrock.py | 12 ++++++------ tests/strands/models/test_bedrock.py | 9 +++++++++ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 586a956af..d4a4b79dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -234,8 +234,8 @@ test-integ = [ "hatch test tests_integ {args}" ] prepare = [ - "hatch fmt --linter", "hatch fmt --formatter", + "hatch fmt --linter", "hatch run test-lint", "hatch test --all" ] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4ea1453a4..ace35640a 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -418,14 +418,14 @@ def _stream( ContextWindowOverflowException: If the input exceeds the model's context window. ModelThrottledException: If the model service is throttling requests. """ - logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) - logger.debug("request=<%s>", request) + try: + logger.debug("formatting request") + request = self.format_request(messages, tool_specs, system_prompt) + logger.debug("request=<%s>", request) - logger.debug("invoking model") - streaming = self.config.get("streaming", True) + logger.debug("invoking model") + streaming = self.config.get("streaming", True) - try: logger.debug("got response from model") if streaming: response = self.client.converse_stream(**request) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 0a2846adf..09e508845 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -419,6 +419,15 @@ async def test_stream_throttling_exception_from_event_stream_error(bedrock_clien ) +@pytest.mark.asyncio +async def test_stream_with_invalid_content_throws(bedrock_client, model, alist): + # We used to hang on None, so ensure we don't regress: https://github.com/strands-agents/sdk-python/issues/642 + messages = [{"role": "user", "content": None}] + + with pytest.raises(TypeError): + await alist(model.stream(messages)) + + @pytest.mark.asyncio async def test_stream_throttling_exception_from_general_exception(bedrock_client, model, messages, alist): error_message = "ThrottlingException: Rate exceeded for ConverseStream" From 04557562eb4345abb65bf056c2889b1586dab277 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Wed, 13 Aug 2025 17:20:34 -0400 Subject: [PATCH 031/104] feat: add structured_output_span (#655) * feat: add structured_output_span --- src/strands/agent/agent.py | 65 ++++++++++++++++++++----------- tests/strands/agent/test_agent.py | 39 +++++++++++++++++++ 2 files changed, 82 insertions(+), 22 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 2022142c6..43b5cbf8c 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -33,7 +33,7 @@ from ..models.model import Model from ..session.session_manager import SessionManager from ..telemetry.metrics import EventLoopMetrics -from ..telemetry.tracer import 
get_tracer +from ..telemetry.tracer import get_tracer, serialize from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages @@ -445,27 +445,48 @@ async def structured_output_async( ValueError: If no conversation history or prompt is provided. """ self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) - - try: - if not self.messages and not prompt: - raise ValueError("No conversation history or prompt provided") - - # Create temporary messages array if prompt is provided - if prompt: - content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt - temp_messages = self.messages + [{"role": "user", "content": content}] - else: - temp_messages = self.messages - - events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) - async for event in events: - if "callback" in event: - self.callback_handler(**cast(dict, event["callback"])) - - return event["output"] - - finally: - self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + with self.tracer.tracer.start_as_current_span( + "execute_structured_output", kind=trace_api.SpanKind.CLIENT + ) as structured_output_span: + try: + if not self.messages and not prompt: + raise ValueError("No conversation history or prompt provided") + # Create temporary messages array if prompt is provided + if prompt: + content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt + temp_messages = self.messages + [{"role": "user", "content": content}] + else: + temp_messages = self.messages + + structured_output_span.set_attributes( + { + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": self.name, + "gen_ai.agent.id": self.agent_id, + "gen_ai.operation.name": "execute_structured_output", + } + ) + for message in temp_messages: + structured_output_span.add_event( + f"gen_ai.{message['role']}.message", + attributes={"role": message["role"], "content": serialize(message["content"])}, + ) + if self.system_prompt: + structured_output_span.add_event( + "gen_ai.system.message", + attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])}, + ) + events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) + async for event in events: + if "callback" in event: + self.callback_handler(**cast(dict, event["callback"])) + structured_output_span.add_event( + "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())} + ) + return event["output"] + + finally: + self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. 
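
A hedged sketch of exercising the instrumented structured-output path above; the Pydantic model, prompt, and default model configuration are illustrative, while the span name and attributes noted in the comment come from the change itself.

from pydantic import BaseModel
from strands import Agent


class Contact(BaseModel):
    name: str
    age: int
    email: str


agent = Agent(system_prompt="Extract contact details from the user's message.")

# Runs inside an "execute_structured_output" span that records gen_ai.system,
# gen_ai.agent.name, gen_ai.agent.id, and gen_ai.operation.name attributes, plus a
# gen_ai.choice event containing the serialized model output.
contact = agent.structured_output(Contact, "Jane Doe is 30 years old and her email is jane@doe.com")
print(contact.name, contact.age, contact.email)
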
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index c27243dfe..fdce7c368 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -980,6 +980,14 @@ def test_agent_callback_handler_custom_handler_used(): def test_agent_structured_output(agent, system_prompt, user, agenerator): + # Setup mock tracer and span + mock_strands_tracer = unittest.mock.MagicMock() + mock_otel_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_strands_tracer.tracer = mock_otel_tracer + mock_otel_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + agent.tracer = mock_strands_tracer + agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) prompt = "Jane Doe is 30 years old and her email is jane@doe.com" @@ -999,8 +1007,34 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator): type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt ) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "Strands Agents", + "gen_ai.agent.id": "default", + "gen_ai.operation.name": "execute_structured_output", + } + ) + + mock_span.add_event.assert_any_call( + "gen_ai.user.message", + attributes={"role": "user", "content": '[{"text": "Jane Doe is 30 years old and her email is jane@doe.com"}]'}, + ) + + mock_span.add_event.assert_called_with( + "gen_ai.choice", + attributes={"message": json.dumps(user.model_dump())}, + ) + def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, agenerator): + # Setup mock tracer and span + mock_strands_tracer = unittest.mock.MagicMock() + mock_otel_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_strands_tracer.tracer = mock_otel_tracer + mock_otel_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + agent.tracer = mock_strands_tracer agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) prompt = [ @@ -1030,6 +1064,11 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt ) + mock_span.add_event.assert_called_with( + "gen_ai.choice", + attributes={"message": json.dumps(user.model_dump())}, + ) + @pytest.mark.asyncio async def test_agent_structured_output_in_async_context(agent, user, agenerator): From 1c7257bc9e2356d025c5fa77f6a3b1e959809964 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 14 Aug 2025 10:32:58 -0400 Subject: [PATCH 032/104] litellm - set 1.73.1 as minimum version (#668) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d4a4b79dc..487b26691 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ docs = [ "sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] litellm = [ - "litellm>=1.72.6,<1.73.0", + "litellm>=1.73.1,<2.0.0", ] llamaapi = [ "llama-api-client>=0.1.0,<1.0.0", From 606f65756668274d3acf2600b76df10745a08f1f Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 14 Aug 2025 14:49:08 -0400 Subject: [PATCH 033/104] feat: expose tool_use and agent through ToolContext to decorated tools (#557) --- src/strands/__init__.py | 3 +- src/strands/tools/decorator.py | 82 +++++++++-- src/strands/types/tools.py | 27 +++- tests/strands/tools/test_decorator.py | 159 ++++++++++++++++++++- 
tests_integ/test_tool_context_injection.py | 56 ++++++++ 5 files changed, 312 insertions(+), 15 deletions(-) create mode 100644 tests_integ/test_tool_context_injection.py diff --git a/src/strands/__init__.py b/src/strands/__init__.py index e9f9e9cd8..ae784a58f 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -3,5 +3,6 @@ from . import agent, models, telemetry, types from .agent.agent import Agent from .tools.decorator import tool +from .types.tools import ToolContext -__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry"] +__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "ToolContext"] diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 5ec324b68..75abac9ed 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic import BaseModel, Field, create_model from typing_extensions import override -from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse +from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -84,16 +84,18 @@ class FunctionToolMetadata: validate tool usage. """ - def __init__(self, func: Callable[..., Any]) -> None: + def __init__(self, func: Callable[..., Any], context_param: str | None = None) -> None: """Initialize with the function to process. Args: func: The function to extract metadata from. Can be a standalone function or a class method. + context_param: Name of the context parameter to inject, if any. """ self.func = func self.signature = inspect.signature(func) self.type_hints = get_type_hints(func) + self._context_param = context_param # Parse the docstring with docstring_parser doc_str = inspect.getdoc(func) or "" @@ -113,7 +115,7 @@ def _create_input_model(self) -> Type[BaseModel]: This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can validate input data before passing it to the function. - Special parameters like 'self', 'cls', and 'agent' are excluded from the model. + Special parameters that can be automatically injected are excluded from the model. Returns: A Pydantic BaseModel class customized for the function's parameters. @@ -121,8 +123,8 @@ def _create_input_model(self) -> Type[BaseModel]: field_definitions: dict[str, Any] = {} for name, param in self.signature.parameters.items(): - # Skip special parameters - if name in ("self", "cls", "agent"): + # Skip parameters that will be automatically injected + if self._is_special_parameter(name): continue # Get parameter type and default @@ -252,6 +254,49 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]: error_msg = str(e) raise ValueError(f"Validation failed for input parameters: {error_msg}") from e + def inject_special_parameters( + self, validated_input: dict[str, Any], tool_use: ToolUse, invocation_state: dict[str, Any] + ) -> None: + """Inject special framework-provided parameters into the validated input. + + This method automatically provides framework-level context to tools that request it + through their function signature. + + Args: + validated_input: The validated input parameters (modified in place). + tool_use: The tool use request containing tool invocation details. + invocation_state: Context for the tool invocation, including agent state. 
+ """ + if self._context_param and self._context_param in self.signature.parameters: + tool_context = ToolContext(tool_use=tool_use, agent=invocation_state["agent"]) + validated_input[self._context_param] = tool_context + + # Inject agent if requested (backward compatibility) + if "agent" in self.signature.parameters and "agent" in invocation_state: + validated_input["agent"] = invocation_state["agent"] + + def _is_special_parameter(self, param_name: str) -> bool: + """Check if a parameter should be automatically injected by the framework or is a standard Python method param. + + Special parameters include: + - Standard Python method parameters: self, cls + - Framework-provided context parameters: agent, and configurable context parameter (defaults to tool_context) + + Args: + param_name: The name of the parameter to check. + + Returns: + True if the parameter should be excluded from input validation and + handled specially during tool execution. + """ + special_params = {"self", "cls", "agent"} + + # Add context parameter if configured + if self._context_param: + special_params.add(self._context_param) + + return param_name in special_params + P = ParamSpec("P") # Captures all parameters R = TypeVar("R") # Return type @@ -402,9 +447,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw # Validate input against the Pydantic model validated_input = self._metadata.validate_input(tool_input) - # Pass along the agent if provided and expected by the function - if "agent" in invocation_state and "agent" in self._metadata.signature.parameters: - validated_input["agent"] = invocation_state.get("agent") + # Inject special framework-provided parameters + self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state) # "Too few arguments" expected, hence the type ignore if inspect.iscoroutinefunction(self._tool_func): @@ -474,6 +518,7 @@ def tool( description: Optional[str] = None, inputSchema: Optional[JSONSchema] = None, name: Optional[str] = None, + context: bool | str = False, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system @@ -482,6 +527,7 @@ def tool( # type: ignore description: Optional[str] = None, inputSchema: Optional[JSONSchema] = None, name: Optional[str] = None, + context: bool | str = False, ) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]: """Decorator that transforms a Python function into a Strands tool. @@ -507,6 +553,9 @@ def tool( # type: ignore description: Optional custom description to override the function's docstring. inputSchema: Optional custom JSON schema to override the automatically generated schema. name: Optional custom name to override the function's name. + context: When provided, places an object in the designated parameter. If True, the param name + defaults to 'tool_context', or if an override is needed, set context equal to a string to designate + the param name. 
Returns: An AgentTool that also mimics the original function when invoked @@ -536,15 +585,24 @@ def my_tool(name: str, count: int = 1) -> str: Example with parameters: ```python - @tool(name="custom_tool", description="A tool with a custom name and description") - def my_tool(name: str, count: int = 1) -> str: - return f"Processed {name} {count} times" + @tool(name="custom_tool", description="A tool with a custom name and description", context=True) + def my_tool(name: str, count: int = 1, tool_context: ToolContext) -> str: + tool_id = tool_context["tool_use"]["toolUseId"] + return f"Processed {name} {count} times with tool ID {tool_id}" ``` """ def decorator(f: T) -> "DecoratedFunctionTool[P, R]": + # Resolve context parameter name + if isinstance(context, bool): + context_param = "tool_context" if context else None + else: + context_param = context.strip() + if not context_param: + raise ValueError("Context parameter name cannot be empty") + # Create function tool metadata - tool_meta = FunctionToolMetadata(f) + tool_meta = FunctionToolMetadata(f, context_param) tool_spec = tool_meta.extract_metadata() if name is not None: tool_spec["name"] = name diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 533e5529c..bb7c874f6 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -6,12 +6,16 @@ """ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union from typing_extensions import TypedDict from .media import DocumentContent, ImageContent +if TYPE_CHECKING: + from .. import Agent + JSONSchema = dict """Type alias for JSON Schema dictionaries.""" @@ -117,6 +121,27 @@ class ToolChoiceTool(TypedDict): name: str +@dataclass +class ToolContext: + """Context object containing framework-provided data for decorated tools. + + This object provides access to framework-level information that may be useful + for tool implementations. + + Attributes: + tool_use: The complete ToolUse object containing tool invocation details. + agent: The Agent instance executing this tool, providing access to conversation history, + model configuration, and other agent state. + + Note: + This class is intended to be instantiated by the SDK. Direct construction by users + is not supported and may break in future versions as new fields are added. 
+ """ + + tool_use: ToolUse + agent: "Agent" + + ToolChoice = Union[ dict[Literal["auto"], ToolChoiceAuto], dict[Literal["any"], ToolChoiceAny], diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 52a9282e0..246879da7 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -8,7 +8,8 @@ import pytest import strands -from strands.types.tools import ToolUse +from strands import Agent +from strands.types.tools import AgentTool, ToolContext, ToolUse @pytest.fixture(scope="module") @@ -1036,3 +1037,159 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None] result = (await alist(stream))[-1] assert result["status"] == "success" assert "NoneType: None" in result["content"][0]["text"] + + +async def _run_context_injection_test(context_tool: AgentTool): + """Common test logic for context injection tests.""" + tool: AgentTool = context_tool + generator = tool.stream( + tool_use={ + "toolUseId": "test-id", + "name": "context_tool", + "input": { + "message": "some_message" # note that we do not include agent nor tool context + }, + }, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + tool_results = [value async for value in generator] + + assert len(tool_results) == 1 + tool_result = tool_results[0] + + assert tool_result == { + "status": "success", + "content": [ + {"text": "Tool 'context_tool' (ID: test-id)"}, + {"text": "injected agent 'test_agent' processed: some_message"}, + {"text": "context agent 'test_agent'"} + ], + "toolUseId": "test-id", + } + + +@pytest.mark.asyncio +async def test_tool_context_injection_default(): + """Test that ToolContext is properly injected with default parameter name (tool_context).""" + + @strands.tool(context=True) + def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict: + """Tool that uses ToolContext to access tool_use_id.""" + tool_use_id = tool_context.tool_use["toolUseId"] + tool_name = tool_context.tool_use["name"] + agent_from_tool_context = tool_context.agent + + return { + "status": "success", + "content": [ + {"text": f"Tool '{tool_name}' (ID: {tool_use_id})"}, + {"text": f"injected agent '{agent.name}' processed: {message}"}, + {"text": f"context agent '{agent_from_tool_context.name}'"}, + ], + } + + await _run_context_injection_test(context_tool) + + +@pytest.mark.asyncio +async def test_tool_context_injection_custom_name(): + """Test that ToolContext is properly injected with custom parameter name.""" + + @strands.tool(context="custom_context_name") + def context_tool(message: str, agent: Agent, custom_context_name: ToolContext) -> dict: + """Tool that uses ToolContext to access tool_use_id.""" + tool_use_id = custom_context_name.tool_use["toolUseId"] + tool_name = custom_context_name.tool_use["name"] + agent_from_tool_context = custom_context_name.agent + + return { + "status": "success", + "content": [ + {"text": f"Tool '{tool_name}' (ID: {tool_use_id})"}, + {"text": f"injected agent '{agent.name}' processed: {message}"}, + {"text": f"context agent '{agent_from_tool_context.name}'"}, + ], + } + + await _run_context_injection_test(context_tool) + + +@pytest.mark.asyncio +async def test_tool_context_injection_disabled_missing_parameter(): + """Test that when context=False, missing tool_context parameter causes validation error.""" + + @strands.tool(context=False) + def context_tool(message: str, agent: Agent, tool_context: str) -> dict: + """Tool that expects tool_context as a regular string 
parameter.""" + return { + "status": "success", + "content": [ + {"text": f"Message: {message}"}, + {"text": f"Agent: {agent.name}"}, + {"text": f"Tool context string: {tool_context}"}, + ], + } + + # Verify that missing tool_context parameter causes validation error + tool: AgentTool = context_tool + generator = tool.stream( + tool_use={ + "toolUseId": "test-id", + "name": "context_tool", + "input": { + "message": "some_message" + # Missing tool_context parameter - should cause validation error instead of being auto injected + }, + }, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + tool_results = [value async for value in generator] + + assert len(tool_results) == 1 + tool_result = tool_results[0] + + # Should get a validation error because tool_context is required but not provided + assert tool_result["status"] == "error" + assert "tool_context" in tool_result["content"][0]["text"].lower() + assert "validation" in tool_result["content"][0]["text"].lower() + + +@pytest.mark.asyncio +async def test_tool_context_injection_disabled_string_parameter(): + """Test that when context=False, tool_context can be passed as a string parameter.""" + + @strands.tool(context=False) + def context_tool(message: str, agent: Agent, tool_context: str) -> str: + """Tool that expects tool_context as a regular string parameter.""" + return "success" + + # Verify that providing tool_context as a string works correctly + tool: AgentTool = context_tool + generator = tool.stream( + tool_use={ + "toolUseId": "test-id-2", + "name": "context_tool", + "input": { + "message": "some_message", + "tool_context": "my_custom_context_string" + }, + }, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + tool_results = [value async for value in generator] + + assert len(tool_results) == 1 + tool_result = tool_results[0] + + # Should succeed with the string parameter + assert tool_result == { + "status": "success", + "content": [{"text": "success"}], + "toolUseId": "test-id-2", + } diff --git a/tests_integ/test_tool_context_injection.py b/tests_integ/test_tool_context_injection.py new file mode 100644 index 000000000..3098604f1 --- /dev/null +++ b/tests_integ/test_tool_context_injection.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +""" +Integration test for ToolContext functionality with real agent interactions. 
+""" + +from strands import Agent, ToolContext, tool +from strands.types.tools import ToolResult + + +@tool(context="custom_context_field") +def good_story(message: str, custom_context_field: ToolContext) -> dict: + """Tool that writes a good story""" + tool_use_id = custom_context_field.tool_use["toolUseId"] + return { + "status": "success", + "content": [{"text": f"Context tool processed with ID: {tool_use_id}"}], + } + + +@tool(context=True) +def bad_story(message: str, tool_context: ToolContext) -> dict: + """Tool that writes a bad story""" + tool_use_id = tool_context.tool_use["toolUseId"] + return { + "status": "success", + "content": [{"text": f"Context tool processed with ID: {tool_use_id}"}], + } + + +def _validate_tool_result_content(agent: Agent): + first_tool_result: ToolResult = [ + block["toolResult"] for message in agent.messages for block in message["content"] if "toolResult" in block + ][0] + + assert first_tool_result["status"] == "success" + assert ( + first_tool_result["content"][0]["text"] == f"Context tool processed with ID: {first_tool_result['toolUseId']}" + ) + + +def test_strands_context_integration_context_true(): + """Test ToolContext functionality with real agent interactions.""" + + agent = Agent(tools=[good_story]) + agent("using a tool, write a good story") + + _validate_tool_result_content(agent) + + +def test_strands_context_integration_context_custom(): + """Test ToolContext functionality with real agent interactions.""" + + agent = Agent(tools=[bad_story]) + agent("using a tool, write a bad story") + + _validate_tool_result_content(agent) From 8c63d75ecf9c246110d297c109bf204839978152 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 15 Aug 2025 17:50:42 -0400 Subject: [PATCH 034/104] session manager - prevent file path injection (#680) --- src/strands/_identifier.py | 30 + src/strands/agent/agent.py | 6 +- src/strands/session/file_session_manager.py | 27 +- src/strands/session/s3_session_manager.py | 23 +- tests/strands/agent/test_agent.py | 12 + .../session/test_file_session_manager.py | 604 +++++++++--------- .../session/test_s3_session_manager.py | 24 + tests/strands/test_identifier.py | 17 + tests/strands/tools/test_decorator.py | 9 +- 9 files changed, 452 insertions(+), 300 deletions(-) create mode 100644 src/strands/_identifier.py create mode 100644 tests/strands/test_identifier.py diff --git a/src/strands/_identifier.py b/src/strands/_identifier.py new file mode 100644 index 000000000..e8b12635c --- /dev/null +++ b/src/strands/_identifier.py @@ -0,0 +1,30 @@ +"""Strands identifier utilities.""" + +import enum +import os + + +class Identifier(enum.Enum): + """Strands identifier types.""" + + AGENT = "agent" + SESSION = "session" + + +def validate(id_: str, type_: Identifier) -> str: + """Validate strands id. + + Args: + id_: Id to validate. + type_: Type of the identifier (e.g., session id, agent id, etc.) + + Returns: + Validated id. + + Raises: + ValueError: If id contains path separators. + """ + if os.path.basename(id_) != id_: + raise ValueError(f"{type_.value}_id={id_} | id cannot contain path separators") + + return id_ diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 43b5cbf8c..38e687af2 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -19,6 +19,7 @@ from opentelemetry import trace as trace_api from pydantic import BaseModel +from .. 
import _identifier from ..event_loop.event_loop import event_loop_cycle, run_tool from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( @@ -249,12 +250,15 @@ def __init__( Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. + + Raises: + ValueError: If agent id contains path separators. """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] self.system_prompt = system_prompt - self.agent_id = agent_id or _DEFAULT_AGENT_ID + self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME self.description = description diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index fec2f0761..9df86e17a 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -7,6 +7,7 @@ import tempfile from typing import Any, Optional, cast +from .. import _identifier from ..types.exceptions import SessionException from ..types.session import Session, SessionAgent, SessionMessage from .repository_session_manager import RepositorySessionManager @@ -40,8 +41,9 @@ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: """Initialize FileSession with filesystem storage. Args: - session_id: ID for the session - storage_dir: Directory for local filesystem storage (defaults to temp dir) + session_id: ID for the session. + ID is not allowed to contain path separators (e.g., a/b). + storage_dir: Directory for local filesystem storage (defaults to temp dir). **kwargs: Additional keyword arguments for future extensibility. """ self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions") @@ -50,12 +52,29 @@ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: super().__init__(session_id=session_id, session_repository=self) def _get_session_path(self, session_id: str) -> str: - """Get session directory path.""" + """Get session directory path. + + Args: + session_id: ID for the session. + + Raises: + ValueError: If session id contains a path separator. + """ + session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) return os.path.join(self.storage_dir, f"{SESSION_PREFIX}{session_id}") def _get_agent_path(self, session_id: str, agent_id: str) -> str: - """Get agent directory path.""" + """Get agent directory path. + + Args: + session_id: ID for the session. + agent_id: ID for the agent. + + Raises: + ValueError: If session id or agent id contains a path separator. + """ session_path = self._get_session_path(session_id) + agent_id = _identifier.validate(agent_id, _identifier.Identifier.AGENT) return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}") def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str: diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 0cc0a68c1..d15e6e3bd 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -8,6 +8,7 @@ from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError +from .. 
import _identifier from ..types.exceptions import SessionException from ..types.session import Session, SessionAgent, SessionMessage from .repository_session_manager import RepositorySessionManager @@ -51,6 +52,7 @@ def __init__( Args: session_id: ID for the session + ID is not allowed to contain path separators (e.g., a/b). bucket: S3 bucket name (required) prefix: S3 key prefix for storage organization boto_session: Optional boto3 session @@ -79,12 +81,29 @@ def __init__( super().__init__(session_id=session_id, session_repository=self) def _get_session_path(self, session_id: str) -> str: - """Get session S3 prefix.""" + """Get session S3 prefix. + + Args: + session_id: ID for the session. + + Raises: + ValueError: If session id contains a path separator. + """ + session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) return f"{self.prefix}/{SESSION_PREFIX}{session_id}/" def _get_agent_path(self, session_id: str, agent_id: str) -> str: - """Get agent S3 prefix.""" + """Get agent S3 prefix. + + Args: + session_id: ID for the session. + agent_id: ID for the agent. + + Raises: + ValueError: If session id or agent id contains a path separator. + """ session_path = self._get_session_path(session_id) + agent_id = _identifier.validate(agent_id, _identifier.Identifier.AGENT) return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/" def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index fdce7c368..ca66ca2bf 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -250,6 +250,18 @@ def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_impo assert tru_tool_names == exp_tool_names +@pytest.mark.parametrize( + "agent_id", + [ + "a/../b", + "a/b", + ], +) +def test_agent__init__invalid_id(agent_id): + with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): + Agent(agent_id=agent_id) + + def test_agent__call__( mock_model, system_prompt, diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index f9fc3ba94..a89222b7e 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -53,310 +53,340 @@ def sample_message(): ) -class TestFileSessionManagerSessionOperations: - """Tests for session operations.""" - - def test_create_session(self, file_manager, sample_session): - """Test creating a session.""" - file_manager.create_session(sample_session) - - # Verify directory structure created - session_path = file_manager._get_session_path(sample_session.session_id) - assert os.path.exists(session_path) - - # Verify session file created - session_file = os.path.join(session_path, "session.json") - assert os.path.exists(session_file) - - # Verify content - with open(session_file, "r") as f: - data = json.load(f) - assert data["session_id"] == sample_session.session_id - assert data["session_type"] == sample_session.session_type - - def test_read_session(self, file_manager, sample_session): - """Test reading an existing session.""" - # Create session first - file_manager.create_session(sample_session) - - # Read it back - result = file_manager.read_session(sample_session.session_id) - - assert result.session_id == sample_session.session_id - assert result.session_type == sample_session.session_type - - def test_read_nonexistent_session(self, file_manager): 
- """Test reading a session that doesn't exist.""" - result = file_manager.read_session("nonexistent-session") - assert result is None - - def test_delete_session(self, file_manager, sample_session): - """Test deleting a session.""" - # Create session first - file_manager.create_session(sample_session) - session_path = file_manager._get_session_path(sample_session.session_id) - assert os.path.exists(session_path) - - # Delete session - file_manager.delete_session(sample_session.session_id) - - # Verify deletion - assert not os.path.exists(session_path) - - def test_delete_nonexistent_session(self, file_manager): - """Test deleting a session that doesn't exist.""" - # Should raise an error according to the implementation - with pytest.raises(SessionException, match="does not exist"): - file_manager.delete_session("nonexistent-session") - - -class TestFileSessionManagerAgentOperations: - """Tests for agent operations.""" - - def test_create_agent(self, file_manager, sample_session, sample_agent): - """Test creating an agent in a session.""" - # Create session first - file_manager.create_session(sample_session) - - # Create agent - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Verify directory structure - agent_path = file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id) - assert os.path.exists(agent_path) - - # Verify agent file - agent_file = os.path.join(agent_path, "agent.json") - assert os.path.exists(agent_file) - - # Verify content - with open(agent_file, "r") as f: - data = json.load(f) - assert data["agent_id"] == sample_agent.agent_id - assert data["state"] == sample_agent.state - - def test_read_agent(self, file_manager, sample_session, sample_agent): - """Test reading an agent from a session.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Read agent - result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) - - assert result.agent_id == sample_agent.agent_id - assert result.state == sample_agent.state - - def test_read_nonexistent_agent(self, file_manager, sample_session): - """Test reading an agent that doesn't exist.""" - result = file_manager.read_agent(sample_session.session_id, "nonexistent_agent") - assert result is None - - def test_update_agent(self, file_manager, sample_session, sample_agent): - """Test updating an agent.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Update agent - sample_agent.state = {"updated": "value"} +def test_create_session(file_manager, sample_session): + """Test creating a session.""" + file_manager.create_session(sample_session) + + # Verify directory structure created + session_path = file_manager._get_session_path(sample_session.session_id) + assert os.path.exists(session_path) + + # Verify session file created + session_file = os.path.join(session_path, "session.json") + assert os.path.exists(session_file) + + # Verify content + with open(session_file, "r") as f: + data = json.load(f) + assert data["session_id"] == sample_session.session_id + assert data["session_type"] == sample_session.session_type + + +def test_read_session(file_manager, sample_session): + """Test reading an existing session.""" + # Create session first + file_manager.create_session(sample_session) + + # Read it back + result = file_manager.read_session(sample_session.session_id) + + 
assert result.session_id == sample_session.session_id + assert result.session_type == sample_session.session_type + + +def test_read_nonexistent_session(file_manager): + """Test reading a session that doesn't exist.""" + result = file_manager.read_session("nonexistent-session") + assert result is None + + +def test_delete_session(file_manager, sample_session): + """Test deleting a session.""" + # Create session first + file_manager.create_session(sample_session) + session_path = file_manager._get_session_path(sample_session.session_id) + assert os.path.exists(session_path) + + # Delete session + file_manager.delete_session(sample_session.session_id) + + # Verify deletion + assert not os.path.exists(session_path) + + +def test_delete_nonexistent_session(file_manager): + """Test deleting a session that doesn't exist.""" + # Should raise an error according to the implementation + with pytest.raises(SessionException, match="does not exist"): + file_manager.delete_session("nonexistent-session") + + +def test_create_agent(file_manager, sample_session, sample_agent): + """Test creating an agent in a session.""" + # Create session first + file_manager.create_session(sample_session) + + # Create agent + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Verify directory structure + agent_path = file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id) + assert os.path.exists(agent_path) + + # Verify agent file + agent_file = os.path.join(agent_path, "agent.json") + assert os.path.exists(agent_file) + + # Verify content + with open(agent_file, "r") as f: + data = json.load(f) + assert data["agent_id"] == sample_agent.agent_id + assert data["state"] == sample_agent.state + + +def test_read_agent(file_manager, sample_session, sample_agent): + """Test reading an agent from a session.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Read agent + result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + + assert result.agent_id == sample_agent.agent_id + assert result.state == sample_agent.state + + +def test_read_nonexistent_agent(file_manager, sample_session): + """Test reading an agent that doesn't exist.""" + result = file_manager.read_agent(sample_session.session_id, "nonexistent_agent") + assert result is None + + +def test_update_agent(file_manager, sample_session, sample_agent): + """Test updating an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Update agent + sample_agent.state = {"updated": "value"} + file_manager.update_agent(sample_session.session_id, sample_agent) + + # Verify update + result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result.state == {"updated": "value"} + + +def test_update_nonexistent_agent(file_manager, sample_session, sample_agent): + """Test updating an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + + # Update agent + with pytest.raises(SessionException): file_manager.update_agent(sample_session.session_id, sample_agent) - # Verify update - result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) - assert result.state == {"updated": "value"} - def test_update_nonexistent_agent(self, file_manager, sample_session, sample_agent): - """Test updating an agent.""" - # Create session and 
agent - file_manager.create_session(sample_session) +def test_create_message(file_manager, sample_session, sample_agent, sample_message): + """Test creating a message for an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) - # Update agent - with pytest.raises(SessionException): - file_manager.update_agent(sample_session.session_id, sample_agent) + # Create message + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify message file + message_path = file_manager._get_message_path( + sample_session.session_id, sample_agent.agent_id, sample_message.message_id + ) + assert os.path.exists(message_path) + # Verify content + with open(message_path, "r") as f: + data = json.load(f) + assert data["message_id"] == sample_message.message_id -class TestFileSessionManagerMessageOperations: - """Tests for message operations.""" - def test_create_message(self, file_manager, sample_session, sample_agent, sample_message): - """Test creating a message for an agent.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) +def test_read_message(file_manager, sample_session, sample_agent, sample_message): + """Test reading a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - # Create message - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + # Create multiple messages when reading + sample_message.message_id = sample_message.message_id + 1 + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - # Verify message file - message_path = file_manager._get_message_path( - sample_session.session_id, sample_agent.agent_id, sample_message.message_id + # Read message + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + + assert result.message_id == sample_message.message_id + assert result.message["role"] == sample_message.message["role"] + assert result.message["content"] == sample_message.message["content"] + + +def test_read_messages_with_new_agent(file_manager, sample_session, sample_agent): + """Test reading a message with with a new agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + + assert result is None + + +def test_read_nonexistent_message(file_manager, sample_session, sample_agent): + """Test reading a message that doesnt exist.""" + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + assert result is None + + +def test_list_messages_all(file_manager, sample_session, sample_agent): + """Test listing all messages for an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + messages = [] + for i in range(5): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message 
{i}")], + }, + message_id=i, ) - assert os.path.exists(message_path) - - # Verify content - with open(message_path, "r") as f: - data = json.load(f) - assert data["message_id"] == sample_message.message_id - - def test_read_message(self, file_manager, sample_session, sample_agent, sample_message): - """Test reading a message.""" - # Create session, agent, and message - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - - # Create multiple messages when reading - sample_message.message_id = sample_message.message_id + 1 - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - - # Read message - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) - - assert result.message_id == sample_message.message_id - assert result.message["role"] == sample_message.message["role"] - assert result.message["content"] == sample_message.message["content"] - - def test_read_messages_with_new_agent(self, file_manager, sample_session, sample_agent): - """Test reading a message with with a new agent.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") - - assert result is None - - def test_read_nonexistent_message(self, file_manager, sample_session, sample_agent): - """Test reading a message that doesnt exist.""" - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") - assert result is None - - def test_list_messages_all(self, file_manager, sample_session, sample_agent): - """Test listing all messages for an agent.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Create multiple messages - messages = [] - for i in range(5): - message = SessionMessage( - message={ - "role": "user", - "content": [ContentBlock(text=f"Message {i}")], - }, - message_id=i, - ) - messages.append(message) - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) - - # List all messages - result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) - - assert len(result) == 5 - - def test_list_messages_with_limit(self, file_manager, sample_session, sample_agent): - """Test listing messages with limit.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Create multiple messages - for i in range(10): - message = SessionMessage( - message={ - "role": "user", - "content": [ContentBlock(text=f"Message {i}")], - }, - message_id=i, - ) - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) - - # List with limit - result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) - - assert len(result) == 3 - - def test_list_messages_with_offset(self, file_manager, sample_session, sample_agent): - """Test listing messages with offset.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Create multiple messages - for 
i in range(10): - message = SessionMessage( - message={ - "role": "user", - "content": [ContentBlock(text=f"Message {i}")], - }, - message_id=i, - ) - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) - - # List with offset - result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) - - assert len(result) == 5 - - def test_list_messages_with_new_agent(self, file_manager, sample_session, sample_agent): - """Test listing messages with new agent.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) - - assert len(result) == 0 - - def test_update_message(self, file_manager, sample_session, sample_agent, sample_message): - """Test updating a message.""" - # Create session, agent, and message - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - - # Update message - sample_message.message["content"] = [ContentBlock(text="Updated content")] - file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + messages.append(message) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) - # Verify update - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) - assert result.message["content"][0]["text"] == "Updated content" + # List all messages + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) - def test_update_nonexistent_message(self, file_manager, sample_session, sample_agent, sample_message): - """Test updating a message.""" - # Create session, agent, and message - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) + assert len(result) == 5 + + +def test_list_messages_with_limit(file_manager, sample_session, sample_agent): + """Test listing messages with limit.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for i in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + }, + message_id=i, + ) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with limit + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) + + assert len(result) == 3 - # Update nonexistent message - with pytest.raises(SessionException): - file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) +def test_list_messages_with_offset(file_manager, sample_session, sample_agent): + """Test listing messages with offset.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) -class TestFileSessionManagerErrorHandling: - """Tests for error handling scenarios.""" + # Create multiple messages + for i in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + }, + message_id=i, + ) + 
file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with offset + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) + + assert len(result) == 5 + + +def test_list_messages_with_new_agent(file_manager, sample_session, sample_agent): + """Test listing messages with new agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 0 + + +def test_update_message(file_manager, sample_session, sample_agent, sample_message): + """Test updating a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - def test_corrupted_json_file(self, file_manager, temp_dir): - """Test handling of corrupted JSON files.""" - # Create a corrupted session file - session_path = os.path.join(temp_dir, "session_test") - os.makedirs(session_path, exist_ok=True) - session_file = os.path.join(session_path, "session.json") + # Update message + sample_message.message["content"] = [ContentBlock(text="Updated content")] + file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) - with open(session_file, "w") as f: - f.write("invalid json content") + # Verify update + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + assert result.message["content"][0]["text"] == "Updated content" - # Should raise SessionException - with pytest.raises(SessionException, match="Invalid JSON"): - file_manager._read_file(session_file) - def test_permission_error_handling(self, file_manager): - """Test handling of permission errors.""" - with patch("builtins.open", side_effect=PermissionError("Access denied")): - session = Session(session_id="test", session_type=SessionType.AGENT) +def test_update_nonexistent_message(file_manager, sample_session, sample_agent, sample_message): + """Test updating a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) - with pytest.raises(SessionException): - file_manager.create_session(session) + # Update nonexistent message + with pytest.raises(SessionException): + file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + +def test_corrupted_json_file(file_manager, temp_dir): + """Test handling of corrupted JSON files.""" + # Create a corrupted session file + session_path = os.path.join(temp_dir, "session_test") + os.makedirs(session_path, exist_ok=True) + session_file = os.path.join(session_path, "session.json") + + with open(session_file, "w") as f: + f.write("invalid json content") + + # Should raise SessionException + with pytest.raises(SessionException, match="Invalid JSON"): + file_manager._read_file(session_file) + + +def test_permission_error_handling(file_manager): + """Test handling of permission errors.""" + with patch("builtins.open", side_effect=PermissionError("Access denied")): + session = Session(session_id="test", session_type=SessionType.AGENT) + + with pytest.raises(SessionException): + file_manager.create_session(session) + + 
+@pytest.mark.parametrize( + "session_id", + [ + "a/../b", + "a/b", + ], +) +def test__get_session_path_invalid_session_id(session_id, file_manager): + with pytest.raises(ValueError, match=f"session_id={session_id} | id cannot contain path separators"): + file_manager._get_session_path(session_id) + + +@pytest.mark.parametrize( + "agent_id", + [ + "a/../b", + "a/b", + ], +) +def test__get_agent_path_invalid_agent_id(agent_id, file_manager): + with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): + file_manager._get_agent_path("session1", agent_id) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index fadd0db4b..71bff3050 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -332,3 +332,27 @@ def test_update_nonexistent_message(s3_manager, sample_session, sample_agent, sa # Update message with pytest.raises(SessionException): s3_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + +@pytest.mark.parametrize( + "session_id", + [ + "a/../b", + "a/b", + ], +) +def test__get_session_path_invalid_session_id(session_id, s3_manager): + with pytest.raises(ValueError, match=f"session_id={session_id} | id cannot contain path separators"): + s3_manager._get_session_path(session_id) + + +@pytest.mark.parametrize( + "agent_id", + [ + "a/../b", + "a/b", + ], +) +def test__get_agent_path_invalid_agent_id(agent_id, s3_manager): + with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): + s3_manager._get_agent_path("session1", agent_id) diff --git a/tests/strands/test_identifier.py b/tests/strands/test_identifier.py new file mode 100644 index 000000000..df673baa8 --- /dev/null +++ b/tests/strands/test_identifier.py @@ -0,0 +1,17 @@ +import pytest + +from strands import _identifier + + +@pytest.mark.parametrize("type_", list(_identifier.Identifier)) +def test_validate(type_): + tru_id = _identifier.validate("abc", type_) + exp_id = "abc" + assert tru_id == exp_id + + +@pytest.mark.parametrize("type_", list(_identifier.Identifier)) +def test_validate_invalid(type_): + id_ = "a/../b" + with pytest.raises(ValueError, match=f"{type_.value}={id_} | id cannot contain path separators"): + _identifier.validate(id_, type_) diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 246879da7..e490c7bb0 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1064,7 +1064,7 @@ async def _run_context_injection_test(context_tool: AgentTool): "content": [ {"text": "Tool 'context_tool' (ID: test-id)"}, {"text": "injected agent 'test_agent' processed: some_message"}, - {"text": "context agent 'test_agent'"} + {"text": "context agent 'test_agent'"}, ], "toolUseId": "test-id", } @@ -1151,7 +1151,7 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> dict: assert len(tool_results) == 1 tool_result = tool_results[0] - + # Should get a validation error because tool_context is required but not provided assert tool_result["status"] == "error" assert "tool_context" in tool_result["content"][0]["text"].lower() @@ -1173,10 +1173,7 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str: tool_use={ "toolUseId": "test-id-2", "name": "context_tool", - "input": { - "message": "some_message", - "tool_context": "my_custom_context_string" - }, + "input": {"message": 
"some_message", "tool_context": "my_custom_context_string"}, }, invocation_state={ "agent": Agent(name="test_agent"), From fbd598a0abea3d1b5a9781f7cdb5819ed81f51ca Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Mon, 18 Aug 2025 06:21:15 -0700 Subject: [PATCH 035/104] fix: only set signature in message if signature was provided by the model (#682) --- src/strands/event_loop/streaming.py | 19 ++++++++++--------- tests/strands/event_loop/test_streaming.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 74cadaf9e..f4048a65c 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -194,16 +194,18 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: state["text"] = "" elif reasoning_text: - content.append( - { - "reasoningContent": { - "reasoningText": { - "text": state["reasoningText"], - "signature": state["signature"], - } + content_block: ContentBlock = { + "reasoningContent": { + "reasoningText": { + "text": state["reasoningText"], } } - ) + } + + if "signature" in state: + content_block["reasoningContent"]["reasoningText"]["signature"] = state["signature"] + + content.append(content_block) state["reasoningText"] = "" return state @@ -263,7 +265,6 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d "text": "", "current_tool_use": {}, "reasoningText": "", - "signature": "", } state["content"] = state["message"]["content"] diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 921fd91de..b1cc312c2 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -216,6 +216,21 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "signature": "123", }, ), + # Reasoning without signature + ( + { + "content": [], + "current_tool_use": {}, + "text": "", + "reasoningText": "test", + }, + { + "content": [{"reasoningContent": {"reasoningText": {"text": "test"}}}], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + }, + ), # Empty ( { From ae74aa33ed9502c72f7d0f46757ec1c5a91fcb00 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Mon, 18 Aug 2025 10:03:33 -0400 Subject: [PATCH 036/104] fix: Add openai dependency to sagemaker dependency group (#678) It depends on OpenAI and we a got a report about the need to install it explicitly Co-authored-by: Mackenzie Zastrow --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 487b26691..6c0b6e3f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,9 @@ writer = [ sagemaker = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", - "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0" + "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", + # uses OpenAI as part of the implementation + "openai>=1.68.0,<2.0.0", ] a2a = [ From 980a988f4cc3b580d37359f3646d2b603715ad69 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Mon, 18 Aug 2025 15:10:58 -0400 Subject: [PATCH 037/104] Have [all] group reference the other optional dependency groups by name (#674) --- pyproject.toml | 49 ++++--------------------------------------------- 1 file changed, 4 insertions(+), 45 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6c0b6e3f7..847db8d2b 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,8 @@ docs = [ ] litellm = [ "litellm>=1.73.1,<2.0.0", + # https://github.com/BerriAI/litellm/issues/13711 + "openai<1.100.0", ] llamaapi = [ "llama-api-client>=0.1.0,<1.0.0", @@ -106,50 +108,7 @@ a2a = [ "starlette>=0.46.2,<1.0.0", ] all = [ - # anthropic - "anthropic>=0.21.0,<1.0.0", - - # dev - "commitizen>=4.4.0,<5.0.0", - "hatch>=1.0.0,<2.0.0", - "moto>=5.1.0,<6.0.0", - "mypy>=1.15.0,<2.0.0", - "pre-commit>=3.2.0,<4.2.0", - "pytest>=8.0.0,<9.0.0", - "pytest-asyncio>=0.26.0,<0.27.0", - "pytest-cov>=4.1.0,<5.0.0", - "pytest-xdist>=3.0.0,<4.0.0", - "ruff>=0.4.4,<0.5.0", - - # docs - "sphinx>=5.0.0,<6.0.0", - "sphinx-rtd-theme>=1.0.0,<2.0.0", - "sphinx-autodoc-typehints>=1.12.0,<2.0.0", - - # litellm - "litellm>=1.72.6,<1.73.0", - - # llama - "llama-api-client>=0.1.0,<1.0.0", - - # mistral - "mistralai>=1.8.2", - - # ollama - "ollama>=0.4.8,<1.0.0", - - # openai - "openai>=1.68.0,<2.0.0", - - # otel - "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", - - # a2a - "a2a-sdk[sql]>=0.3.0,<0.4.0", - "uvicorn>=0.34.2,<1.0.0", - "httpx>=0.28.1,<1.0.0", - "fastapi>=0.115.12,<1.0.0", - "starlette>=0.46.2,<1.0.0", + "strands-agents[a2a,anthropic,dev,docs,litellm,llamaapi,mistral,ollama,openai,otel]", ] [tool.hatch.version] @@ -161,7 +120,7 @@ features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mis dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", - "strands-agents @ {root:uri}" + "strands-agents @ {root:uri}", ] [tool.hatch.envs.hatch-static-analysis.scripts] From b1df148fbc89bb057348a897ea42fa3c6501ac63 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Mon, 18 Aug 2025 16:06:28 -0400 Subject: [PATCH 038/104] fix: append blank text content if assistant content is empty (#677) --- src/strands/event_loop/streaming.py | 6 +++--- tests/strands/event_loop/test_streaming.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index f4048a65c..1f8c260a4 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -40,10 +40,12 @@ def remove_blank_messages_content_text(messages: Messages) -> Messages: # only modify assistant messages if "role" in message and message["role"] != "assistant": continue - if "content" in message: content = message["content"] has_tool_use = any("toolUse" in item for item in content) + if len(content) == 0: + content.append({"text": "[blank text]"}) + continue if has_tool_use: # Remove blank 'text' items for assistant messages @@ -273,7 +275,6 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d async for chunk in chunks: yield {"callback": {"event": chunk}} - if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) elif "contentBlockStart" in chunk: @@ -313,7 +314,6 @@ async def stream_messages( logger.debug("model=<%s> | streaming messages", model) messages = remove_blank_messages_content_text(messages) - chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt) async for event in process_stream(chunks): diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index b1cc312c2..66deb282c 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -26,6 +26,7 @@ def moto_autouse(moto_env, moto_mock_aws): {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, 
{"toolUse": {}}]}, {"role": "assistant", "content": [{"text": ""}, {"toolUse": {}}]}, {"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}, + {"role": "assistant", "content": []}, {"role": "assistant"}, {"role": "user", "content": [{"text": " \n"}]}, ], @@ -33,6 +34,7 @@ def moto_autouse(moto_env, moto_mock_aws): {"role": "assistant", "content": [{"text": "a"}, {"toolUse": {}}]}, {"role": "assistant", "content": [{"toolUse": {}}]}, {"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]}, + {"role": "assistant", "content": [{"text": "[blank text]"}]}, {"role": "assistant"}, {"role": "user", "content": [{"text": " \n"}]}, ], From cfcf93dc781cc0300f3faae19530c527bbe595ad Mon Sep 17 00:00:00 2001 From: Oz Altagar Date: Tue, 19 Aug 2025 00:06:08 +0300 Subject: [PATCH 039/104] feat: add cached token metrics support for Amazon Bedrock (#531) * feat: add cached token metrics support for Amazon Bedrock - Add optional cacheReadInputTokens and cacheWriteInputTokens fields to Usage TypedDict - Update EventLoopMetrics to accumulate cached token metrics - Add OpenTelemetry instrumentation for cached token telemetry - Enhance metrics summary display to show cached token information - Maintain 100% backward compatibility with existing Usage objects - Add comprehensive test coverage for cached token functionality Resolves #529 * feat: updated cached read/write input token metrics --------- Co-authored-by: poshinchen --- src/strands/telemetry/metrics.py | 45 +++++++++++++++++++--- src/strands/telemetry/metrics_constants.py | 2 + src/strands/types/event_loop.py | 16 +++++--- tests/strands/event_loop/test_streaming.py | 12 ++++++ tests/strands/telemetry/test_metrics.py | 8 ++-- 5 files changed, 66 insertions(+), 17 deletions(-) diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index 332ab2ae3..883273f64 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -11,7 +11,7 @@ from ..telemetry import metrics_constants as constants from ..types.content import Message -from ..types.streaming import Metrics, Usage +from ..types.event_loop import Metrics, Usage from ..types.tools import ToolUse logger = logging.getLogger(__name__) @@ -264,6 +264,21 @@ def update_usage(self, usage: Usage) -> None: self.accumulated_usage["outputTokens"] += usage["outputTokens"] self.accumulated_usage["totalTokens"] += usage["totalTokens"] + # Handle optional cached token metrics + if "cacheReadInputTokens" in usage: + cache_read_tokens = usage["cacheReadInputTokens"] + self._metrics_client.event_loop_cache_read_input_tokens.record(cache_read_tokens) + self.accumulated_usage["cacheReadInputTokens"] = ( + self.accumulated_usage.get("cacheReadInputTokens", 0) + cache_read_tokens + ) + + if "cacheWriteInputTokens" in usage: + cache_write_tokens = usage["cacheWriteInputTokens"] + self._metrics_client.event_loop_cache_write_input_tokens.record(cache_write_tokens) + self.accumulated_usage["cacheWriteInputTokens"] = ( + self.accumulated_usage.get("cacheWriteInputTokens", 0) + cache_write_tokens + ) + def update_metrics(self, metrics: Metrics) -> None: """Update the accumulated performance metrics with new metrics data. 
@@ -325,11 +340,21 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name f"ā”œā”€ Cycles: total={summary['total_cycles']}, avg_time={summary['average_cycle_time']:.3f}s, " f"total_time={summary['total_duration']:.3f}s" ) - yield ( - f"ā”œā”€ Tokens: in={summary['accumulated_usage']['inputTokens']}, " - f"out={summary['accumulated_usage']['outputTokens']}, " - f"total={summary['accumulated_usage']['totalTokens']}" - ) + + # Build token display with optional cached tokens + token_parts = [ + f"in={summary['accumulated_usage']['inputTokens']}", + f"out={summary['accumulated_usage']['outputTokens']}", + f"total={summary['accumulated_usage']['totalTokens']}", + ] + + # Add cached token info if present + if summary["accumulated_usage"].get("cacheReadInputTokens"): + token_parts.append(f"cache_read_input_tokens={summary['accumulated_usage']['cacheReadInputTokens']}") + if summary["accumulated_usage"].get("cacheWriteInputTokens"): + token_parts.append(f"cache_write_input_tokens={summary['accumulated_usage']['cacheWriteInputTokens']}") + + yield f"ā”œā”€ Tokens: {', '.join(token_parts)}" yield f"ā”œā”€ Bedrock Latency: {summary['accumulated_metrics']['latencyMs']}ms" yield "ā”œā”€ Tool Usage:" @@ -421,6 +446,8 @@ class MetricsClient: event_loop_latency: Histogram event_loop_input_tokens: Histogram event_loop_output_tokens: Histogram + event_loop_cache_read_input_tokens: Histogram + event_loop_cache_write_input_tokens: Histogram tool_call_count: Counter tool_success_count: Counter @@ -474,3 +501,9 @@ def create_instruments(self) -> None: self.event_loop_output_tokens = self.meter.create_histogram( name=constants.STRANDS_EVENT_LOOP_OUTPUT_TOKENS, unit="token" ) + self.event_loop_cache_read_input_tokens = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS, unit="token" + ) + self.event_loop_cache_write_input_tokens = self.meter.create_histogram( + name=constants.STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS, unit="token" + ) diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py index b622eebff..f8fac34da 100644 --- a/src/strands/telemetry/metrics_constants.py +++ b/src/strands/telemetry/metrics_constants.py @@ -13,3 +13,5 @@ STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration" STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens" STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens" +STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS = "strands.event_loop.cache_read.input.tokens" +STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS = "strands.event_loop.cache_write.input.tokens" diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 7be33b6fd..2c240972b 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -2,21 +2,25 @@ from typing import Literal -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict -class Usage(TypedDict): +class Usage(TypedDict, total=False): """Token usage information for model interactions. Attributes: - inputTokens: Number of tokens sent in the request to the model.. + inputTokens: Number of tokens sent in the request to the model. outputTokens: Number of tokens that the model generated for the request. totalTokens: Total number of tokens (input + output). + cacheReadInputTokens: Number of tokens read from cache (optional). + cacheWriteInputTokens: Number of tokens written to cache (optional). 
""" - inputTokens: int - outputTokens: int - totalTokens: int + inputTokens: Required[int] + outputTokens: Required[int] + totalTokens: Required[int] + cacheReadInputTokens: int + cacheWriteInputTokens: int class Metrics(TypedDict): diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 66deb282c..7760c498a 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -277,6 +277,18 @@ def test_extract_usage_metrics(): assert tru_usage == exp_usage and tru_metrics == exp_metrics +def test_extract_usage_metrics_with_cache_tokens(): + event = { + "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0, "cacheReadInputTokens": 0}, + "metrics": {"latencyMs": 0}, + } + + tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event) + exp_usage, exp_metrics = event["usage"], event["metrics"] + + assert tru_usage == exp_usage and tru_metrics == exp_metrics + + @pytest.mark.parametrize( ("response", "exp_events"), [ diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 215e1efde..12db81908 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -90,6 +90,7 @@ def usage(request): "inputTokens": 1, "outputTokens": 2, "totalTokens": 3, + "cacheWriteInputTokens": 2, } if hasattr(request, "param"): params.update(request.param) @@ -315,17 +316,14 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_met event_loop_metrics.update_usage(usage) tru_usage = event_loop_metrics.accumulated_usage - exp_usage = Usage( - inputTokens=3, - outputTokens=6, - totalTokens=9, - ) + exp_usage = Usage(inputTokens=3, outputTokens=6, totalTokens=9, cacheWriteInputTokens=6) assert tru_usage == exp_usage mock_get_meter_provider.return_value.get_meter.assert_called() metrics_client = event_loop_metrics._metrics_client metrics_client.event_loop_input_tokens.record.assert_called() metrics_client.event_loop_output_tokens.record.assert_called() + metrics_client.event_loop_cache_write_input_tokens.record.assert_called() def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider): From c087f1883dcad7481de2499cb2d2d891c19e4ee7 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:06:45 +0800 Subject: [PATCH 040/104] fix: fix non-serializable parameter of agent from toolUse block (#568) * fix: fix non-serializable parameter of agent from toolUse block * feat: Add configuration option to MCP Client for server init timeout (#657) Co-authored-by: Harry Wilton * fix: Bedrock hang when exception occurs during message conversion (#643) Previously (#642) bedrock would hang during message conversion because the exception was not being caught and thus the queue was always empty. 
Now all exceptions during conversion are caught Co-authored-by: Mackenzie Zastrow * fix: only include parameters that defined in tool spec --------- Co-authored-by: Jack Yuan Co-authored-by: fhwilton55 <81768750+fhwilton55@users.noreply.github.com> Co-authored-by: Harry Wilton Co-authored-by: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Co-authored-by: Mackenzie Zastrow --- src/strands/agent/agent.py | 33 +++++++- tests/strands/agent/test_agent.py | 127 ++++++++---------------------- 2 files changed, 65 insertions(+), 95 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 38e687af2..acc6a7650 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -642,8 +642,11 @@ def _record_tool_execution( tool_result: The result returned by the tool. user_message_override: Optional custom message to include. """ + # Filter tool input parameters to only include those defined in tool spec + filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) + # Create user message describing the tool call - input_parameters = json.dumps(tool["input"], default=lambda o: f"<>") + input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") user_msg_content: list[ContentBlock] = [ {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} @@ -653,6 +656,13 @@ def _record_tool_execution( if user_message_override: user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) + # Create filtered tool use for message history + filtered_tool: ToolUse = { + "toolUseId": tool["toolUseId"], + "name": tool["name"], + "input": filtered_input, + } + # Create the message sequence user_msg: Message = { "role": "user", @@ -660,7 +670,7 @@ def _record_tool_execution( } tool_use_msg: Message = { "role": "assistant", - "content": [{"toolUse": tool}], + "content": [{"toolUse": filtered_tool}], } tool_result_msg: Message = { "role": "user", @@ -717,6 +727,25 @@ def _end_agent_trace_span( self.tracer.end_agent_span(**trace_attributes) + def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: + """Filter input parameters to only include those defined in the tool specification. + + Args: + tool_name: Name of the tool to get specification for + input_params: Original input parameters + + Returns: + Filtered parameters containing only those defined in tool spec + """ + all_tools_config = self.tool_registry.get_all_tools_config() + tool_spec = all_tools_config.get(tool_name) + + if not tool_spec or "inputSchema" not in tool_spec: + return input_params.copy() + + properties = tool_spec["inputSchema"]["json"]["properties"] + return {k: v for k, v in input_params.items() if k in properties} + def _append_message(self, message: Message) -> None: """Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent.""" self.messages.append(message) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index ca66ca2bf..444232455 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1738,99 +1738,7 @@ def test_agent_tool_non_serializable_parameter_filtering(agent, mock_randint): tool_call_text = user_message["content"][1]["text"] assert "agent.tool.tool_decorated direct tool call." 
in tool_call_text assert '"random_string": "test_value"' in tool_call_text - assert '"non_serializable_agent": "<>"' in tool_call_text - - -def test_agent_tool_multiple_non_serializable_types(agent, mock_randint): - """Test filtering of various non-serializable object types.""" - mock_randint.return_value = 123 - - # Create various non-serializable objects - class CustomClass: - def __init__(self, value): - self.value = value - - non_serializable_objects = { - "agent": Agent(), - "custom_object": CustomClass("test"), - "function": lambda x: x, - "set_object": {1, 2, 3}, - "complex_number": 3 + 4j, - "serializable_string": "this_should_remain", - "serializable_number": 42, - "serializable_list": [1, 2, 3], - "serializable_dict": {"key": "value"}, - } - - # This should not crash - result = agent.tool.tool_decorated(random_string="test_filtering", **non_serializable_objects) - - # Verify tool executed successfully - expected_result = { - "content": [{"text": "test_filtering"}], - "status": "success", - "toolUseId": "tooluse_tool_decorated_123", - } - assert result == expected_result - - # Check the recorded message for proper parameter filtering - assert len(agent.messages) > 0 - user_message = agent.messages[0] - tool_call_text = user_message["content"][0]["text"] - - # Verify serializable objects remain unchanged - assert '"serializable_string": "this_should_remain"' in tool_call_text - assert '"serializable_number": 42' in tool_call_text - assert '"serializable_list": [1, 2, 3]' in tool_call_text - assert '"serializable_dict": {"key": "value"}' in tool_call_text - - # Verify non-serializable objects are replaced with descriptive strings - assert '"agent": "<>"' in tool_call_text - assert ( - '"custom_object": "<.CustomClass>>"' - in tool_call_text - ) - assert '"function": "<>"' in tool_call_text - assert '"set_object": "<>"' in tool_call_text - assert '"complex_number": "<>"' in tool_call_text - - -def test_agent_tool_serialization_edge_cases(agent, mock_randint): - """Test edge cases in parameter serialization filtering.""" - mock_randint.return_value = 999 - - # Test with None values, empty containers, and nested structures - edge_case_params = { - "none_value": None, - "empty_list": [], - "empty_dict": {}, - "nested_list_with_non_serializable": [1, 2, Agent()], # This should be filtered out - "nested_dict_serializable": {"nested": {"key": "value"}}, # This should remain - } - - result = agent.tool.tool_decorated(random_string="edge_cases", **edge_case_params) - - # Verify successful execution - expected_result = { - "content": [{"text": "edge_cases"}], - "status": "success", - "toolUseId": "tooluse_tool_decorated_999", - } - assert result == expected_result - - # Check parameter filtering in recorded message - assert len(agent.messages) > 0 - user_message = agent.messages[0] - tool_call_text = user_message["content"][0]["text"] - - # Verify serializable values remain - assert '"none_value": null' in tool_call_text - assert '"empty_list": []' in tool_call_text - assert '"empty_dict": {}' in tool_call_text - assert '"nested_dict_serializable": {"nested": {"key": "value"}}' in tool_call_text - - # Verify non-serializable nested structure is replaced - assert '"nested_list_with_non_serializable": [1, 2, "<>"]' in tool_call_text + assert '"non_serializable_agent": "<>"' not in tool_call_text def test_agent_tool_no_non_serializable_parameters(agent, mock_randint): @@ -1882,3 +1790,36 @@ def test_agent_tool_record_direct_tool_call_disabled_with_non_serializable(agent # Verify no messages were 
recorded assert len(agent.messages) == 0 + + +def test_agent_tool_call_parameter_filtering_integration(mock_randint): + """Test that tool calls properly filter parameters in message recording.""" + mock_randint.return_value = 42 + + @strands.tool + def test_tool(action: str) -> str: + """Test tool with single parameter.""" + return action + + agent = Agent(tools=[test_tool]) + + # Call tool with extra non-spec parameters + result = agent.tool.test_tool( + action="test_value", + agent=agent, # Should be filtered out + extra_param="filtered", # Should be filtered out + ) + + # Verify tool executed successfully + assert result["status"] == "success" + assert result["content"] == [{"text": "test_value"}] + + # Check that only spec parameters are recorded in message history + assert len(agent.messages) > 0 + user_message = agent.messages[0] + tool_call_text = user_message["content"][0]["text"] + + # Should only contain the 'action' parameter + assert '"action": "test_value"' in tool_call_text + assert '"agent"' not in tool_call_text + assert '"extra_param"' not in tool_call_text From 17ccdd2df3ee7a213fe59a24d51a5ea238879117 Mon Sep 17 00:00:00 2001 From: vawsgit <147627358+vawsgit@users.noreply.github.com> Date: Tue, 19 Aug 2025 10:15:51 -0500 Subject: [PATCH 041/104] chore: add .DS_Store to .gitignore (#681) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index c27d1d902..888a96bbc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.DS_Store build __pycache__* .coverage* From ef18a255d5949b9ebbd46f08575ea881b5c64106 Mon Sep 17 00:00:00 2001 From: Jeremiah Date: Wed, 20 Aug 2025 13:21:02 -0400 Subject: [PATCH 042/104] feat(a2a): support A2A FileParts and DataParts (#596) Co-authored-by: jer --- src/strands/multiagent/a2a/executor.py | 185 +++- tests/strands/multiagent/a2a/test_executor.py | 787 +++++++++++++++++- 2 files changed, 947 insertions(+), 25 deletions(-) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 5bf9cbfe9..74ecc6531 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -8,18 +8,29 @@ streamed requests to the A2AServer. """ +import json import logging -from typing import Any +import mimetypes +from typing import Any, Literal from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue from a2a.server.tasks import TaskUpdater -from a2a.types import InternalError, Part, TaskState, TextPart, UnsupportedOperationError +from a2a.types import DataPart, FilePart, InternalError, Part, TaskState, TextPart, UnsupportedOperationError from a2a.utils import new_agent_text_message, new_task from a2a.utils.errors import ServerError from ...agent.agent import Agent as SAAgent from ...agent.agent import AgentResult as SAAgentResult +from ...types.content import ContentBlock +from ...types.media import ( + DocumentContent, + DocumentSource, + ImageContent, + ImageSource, + VideoContent, + VideoSource, +) logger = logging.getLogger(__name__) @@ -31,6 +42,12 @@ class StrandsA2AExecutor(AgentExecutor): and converts Strands Agent responses to A2A protocol events. 
""" + # Default formats for each file type when MIME type is unavailable or unrecognized + DEFAULT_FORMATS = {"document": "txt", "image": "png", "video": "mp4", "unknown": "txt"} + + # Handle special cases where format differs from extension + FORMAT_MAPPINGS = {"jpg": "jpeg", "htm": "html", "3gp": "three_gp", "3gpp": "three_gp", "3g2": "three_gp"} + def __init__(self, agent: SAAgent): """Initialize a StrandsA2AExecutor. @@ -78,10 +95,16 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater context: The A2A request context, containing the user's input and other metadata. updater: The task updater for managing task state and sending updates. """ - logger.info("Executing request in streaming mode") - user_input = context.get_user_input() + # Convert A2A message parts to Strands ContentBlocks + if context.message and hasattr(context.message, "parts"): + content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts) + if not content_blocks: + raise ValueError("No content blocks available") + else: + raise ValueError("No content blocks available") + try: - async for event in self.agent.stream_async(user_input): + async for event in self.agent.stream_async(content_blocks): await self._handle_streaming_event(event, updater) except Exception: logger.exception("Error in streaming execution") @@ -146,3 +169,155 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None """ logger.warning("Cancellation requested but not supported") raise ServerError(error=UnsupportedOperationError()) + + def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]: + """Classify file type based on MIME type. + + Args: + mime_type: The MIME type of the file + + Returns: + The classified file type + """ + if not mime_type: + return "unknown" + + mime_type = mime_type.lower() + + if mime_type.startswith("image/"): + return "image" + elif mime_type.startswith("video/"): + return "video" + elif ( + mime_type.startswith("text/") + or mime_type.startswith("application/") + or mime_type in ["application/pdf", "application/json", "application/xml"] + ): + return "document" + else: + return "unknown" + + def _get_file_format_from_mime_type(self, mime_type: str | None, file_type: str) -> str: + """Extract file format from MIME type using Python's mimetypes library. + + Args: + mime_type: The MIME type of the file + file_type: The classified file type (image, video, document, txt) + + Returns: + The file format string + """ + if not mime_type: + return self.DEFAULT_FORMATS.get(file_type, "txt") + + mime_type = mime_type.lower() + + # Extract subtype from MIME type and check existing format mappings + if "/" in mime_type: + subtype = mime_type.split("/")[-1] + if subtype in self.FORMAT_MAPPINGS: + return self.FORMAT_MAPPINGS[subtype] + + # Use mimetypes library to find extensions for the MIME type + extensions = mimetypes.guess_all_extensions(mime_type) + + if extensions: + extension = extensions[0][1:] # Remove the leading dot + return self.FORMAT_MAPPINGS.get(extension, extension) + + # Fallback to defaults for unknown MIME types + return self.DEFAULT_FORMATS.get(file_type, "txt") + + def _strip_file_extension(self, file_name: str) -> str: + """Strip the file extension from a file name. + + Args: + file_name: The original file name with extension + + Returns: + The file name without extension + """ + if "." 
in file_name: + return file_name.rsplit(".", 1)[0] + return file_name + + def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[ContentBlock]: + """Convert A2A message parts to Strands ContentBlocks. + + Args: + parts: List of A2A Part objects + + Returns: + List of Strands ContentBlock objects + """ + content_blocks: list[ContentBlock] = [] + + for part in parts: + try: + part_root = part.root + + if isinstance(part_root, TextPart): + # Handle TextPart + content_blocks.append(ContentBlock(text=part_root.text)) + + elif isinstance(part_root, FilePart): + # Handle FilePart + file_obj = part_root.file + mime_type = getattr(file_obj, "mime_type", None) + raw_file_name = getattr(file_obj, "name", "FileNameNotProvided") + file_name = self._strip_file_extension(raw_file_name) + file_type = self._get_file_type_from_mime_type(mime_type) + file_format = self._get_file_format_from_mime_type(mime_type, file_type) + + # Handle FileWithBytes vs FileWithUri + bytes_data = getattr(file_obj, "bytes", None) + uri_data = getattr(file_obj, "uri", None) + + if bytes_data: + if file_type == "image": + content_blocks.append( + ContentBlock( + image=ImageContent( + format=file_format, # type: ignore + source=ImageSource(bytes=bytes_data), + ) + ) + ) + elif file_type == "video": + content_blocks.append( + ContentBlock( + video=VideoContent( + format=file_format, # type: ignore + source=VideoSource(bytes=bytes_data), + ) + ) + ) + else: # document or unknown + content_blocks.append( + ContentBlock( + document=DocumentContent( + format=file_format, # type: ignore + name=file_name, + source=DocumentSource(bytes=bytes_data), + ) + ) + ) + # Handle FileWithUri + elif uri_data: + # For URI files, create a text representation since Strands ContentBlocks expect bytes + content_blocks.append( + ContentBlock( + text="[File: %s (%s)] - Referenced file at: %s" % (file_name, mime_type, uri_data) + ) + ) + elif isinstance(part_root, DataPart): + # Handle DataPart - convert structured data to JSON text + try: + data_text = json.dumps(part_root.data, indent=2) + content_blocks.append(ContentBlock(text="[Structured Data]\n%s" % data_text)) + except Exception: + logger.exception("Failed to serialize data part") + except Exception: + logger.exception("Error processing part") + + return content_blocks diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 77645fc73..3f63119f2 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -3,11 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from a2a.types import UnsupportedOperationError +from a2a.types import InternalError, UnsupportedOperationError from a2a.utils.errors import ServerError from strands.agent.agent_result import AgentResult as SAAgentResult from strands.multiagent.a2a.executor import StrandsA2AExecutor +from strands.types.content import ContentBlock def test_executor_initialization(mock_strands_agent): @@ -17,18 +18,304 @@ def test_executor_initialization(mock_strands_agent): assert executor.agent == mock_strands_agent +def test_classify_file_type(): + """Test file type classification based on MIME type.""" + executor = StrandsA2AExecutor(MagicMock()) + + # Test image types + assert executor._get_file_type_from_mime_type("image/jpeg") == "image" + assert executor._get_file_type_from_mime_type("image/png") == "image" + + # Test video types + assert executor._get_file_type_from_mime_type("video/mp4") == "video" 
+ assert executor._get_file_type_from_mime_type("video/mpeg") == "video" + + # Test document types + assert executor._get_file_type_from_mime_type("text/plain") == "document" + assert executor._get_file_type_from_mime_type("application/pdf") == "document" + assert executor._get_file_type_from_mime_type("application/json") == "document" + + # Test unknown/edge cases + assert executor._get_file_type_from_mime_type("audio/mp3") == "unknown" + assert executor._get_file_type_from_mime_type(None) == "unknown" + assert executor._get_file_type_from_mime_type("") == "unknown" + + +def test_get_file_format_from_mime_type(): + """Test file format extraction from MIME type using mimetypes library.""" + executor = StrandsA2AExecutor(MagicMock()) + assert executor._get_file_format_from_mime_type("image/jpeg", "image") == "jpeg" + assert executor._get_file_format_from_mime_type("image/png", "image") == "png" + assert executor._get_file_format_from_mime_type("image/unknown", "image") == "png" + + # Test video formats + assert executor._get_file_format_from_mime_type("video/mp4", "video") == "mp4" + assert executor._get_file_format_from_mime_type("video/3gpp", "video") == "three_gp" + assert executor._get_file_format_from_mime_type("video/unknown", "video") == "mp4" + + # Test document formats + assert executor._get_file_format_from_mime_type("application/pdf", "document") == "pdf" + assert executor._get_file_format_from_mime_type("text/plain", "document") == "txt" + assert executor._get_file_format_from_mime_type("application/unknown", "document") == "txt" + + # Test None/empty cases + assert executor._get_file_format_from_mime_type(None, "image") == "png" + assert executor._get_file_format_from_mime_type("", "video") == "mp4" + + +def test_strip_file_extension(): + """Test file extension stripping.""" + executor = StrandsA2AExecutor(MagicMock()) + + assert executor._strip_file_extension("test.txt") == "test" + assert executor._strip_file_extension("document.pdf") == "document" + assert executor._strip_file_extension("image.jpeg") == "image" + assert executor._strip_file_extension("no_extension") == "no_extension" + assert executor._strip_file_extension("multiple.dots.file.ext") == "multiple.dots.file" + + +def test_convert_a2a_parts_to_content_blocks_text_part(): + """Test conversion of TextPart to ContentBlock.""" + from a2a.types import TextPart + + executor = StrandsA2AExecutor(MagicMock()) + + # Mock TextPart with proper spec + text_part = MagicMock(spec=TextPart) + text_part.text = "Hello, world!" 
+ + # Mock Part with TextPart root + part = MagicMock() + part.root = text_part + + result = executor._convert_a2a_parts_to_content_blocks([part]) + + assert len(result) == 1 + assert result[0] == ContentBlock(text="Hello, world!") + + +def test_convert_a2a_parts_to_content_blocks_file_part_image_bytes(): + """Test conversion of FilePart with image bytes to ContentBlock.""" + from a2a.types import FilePart + + executor = StrandsA2AExecutor(MagicMock()) + + # Create test image bytes (no base64 encoding needed) + test_bytes = b"fake_image_data" + + # Mock file object + file_obj = MagicMock() + file_obj.name = "test_image.jpeg" + file_obj.mime_type = "image/jpeg" + file_obj.bytes = test_bytes + file_obj.uri = None + + # Mock FilePart with proper spec + file_part = MagicMock(spec=FilePart) + file_part.file = file_obj + + # Mock Part with FilePart root + part = MagicMock() + part.root = file_part + + result = executor._convert_a2a_parts_to_content_blocks([part]) + + assert len(result) == 1 + content_block = result[0] + assert "image" in content_block + assert content_block["image"]["format"] == "jpeg" + assert content_block["image"]["source"]["bytes"] == test_bytes + + +def test_convert_a2a_parts_to_content_blocks_file_part_video_bytes(): + """Test conversion of FilePart with video bytes to ContentBlock.""" + from a2a.types import FilePart + + executor = StrandsA2AExecutor(MagicMock()) + + # Create test video bytes (no base64 encoding needed) + test_bytes = b"fake_video_data" + + # Mock file object + file_obj = MagicMock() + file_obj.name = "test_video.mp4" + file_obj.mime_type = "video/mp4" + file_obj.bytes = test_bytes + file_obj.uri = None + + # Mock FilePart with proper spec + file_part = MagicMock(spec=FilePart) + file_part.file = file_obj + + # Mock Part with FilePart root + part = MagicMock() + part.root = file_part + + result = executor._convert_a2a_parts_to_content_blocks([part]) + + assert len(result) == 1 + content_block = result[0] + assert "video" in content_block + assert content_block["video"]["format"] == "mp4" + assert content_block["video"]["source"]["bytes"] == test_bytes + + +def test_convert_a2a_parts_to_content_blocks_file_part_document_bytes(): + """Test conversion of FilePart with document bytes to ContentBlock.""" + from a2a.types import FilePart + + executor = StrandsA2AExecutor(MagicMock()) + + # Create test document bytes (no base64 encoding needed) + test_bytes = b"fake_document_data" + + # Mock file object + file_obj = MagicMock() + file_obj.name = "test_document.pdf" + file_obj.mime_type = "application/pdf" + file_obj.bytes = test_bytes + file_obj.uri = None + + # Mock FilePart with proper spec + file_part = MagicMock(spec=FilePart) + file_part.file = file_obj + + # Mock Part with FilePart root + part = MagicMock() + part.root = file_part + + result = executor._convert_a2a_parts_to_content_blocks([part]) + + assert len(result) == 1 + content_block = result[0] + assert "document" in content_block + assert content_block["document"]["format"] == "pdf" + assert content_block["document"]["name"] == "test_document" + assert content_block["document"]["source"]["bytes"] == test_bytes + + +def test_convert_a2a_parts_to_content_blocks_file_part_uri(): + """Test conversion of FilePart with URI to ContentBlock.""" + from a2a.types import FilePart + + executor = StrandsA2AExecutor(MagicMock()) + + # Mock file object with URI + file_obj = MagicMock() + file_obj.name = "test_image.png" + file_obj.mime_type = "image/png" + file_obj.bytes = None + file_obj.uri = 
"https://example.com/image.png" + + # Mock FilePart with proper spec + file_part = MagicMock(spec=FilePart) + file_part.file = file_obj + + # Mock Part with FilePart root + part = MagicMock() + part.root = file_part + + result = executor._convert_a2a_parts_to_content_blocks([part]) + + assert len(result) == 1 + content_block = result[0] + assert "text" in content_block + assert "test_image" in content_block["text"] + assert "https://example.com/image.png" in content_block["text"] + + +def test_convert_a2a_parts_to_content_blocks_file_part_with_bytes(): + """Test conversion of FilePart with bytes data.""" + from a2a.types import FilePart + + executor = StrandsA2AExecutor(MagicMock()) + + # Mock file object with bytes (no validation needed since no decoding) + file_obj = MagicMock() + file_obj.name = "test_image.png" + file_obj.mime_type = "image/png" + file_obj.bytes = b"some_binary_data" + file_obj.uri = None + + # Mock FilePart with proper spec + file_part = MagicMock(spec=FilePart) + file_part.file = file_obj + + # Mock Part with FilePart root + part = MagicMock() + part.root = file_part + + result = executor._convert_a2a_parts_to_content_blocks([part]) + + assert len(result) == 1 + content_block = result[0] + assert "image" in content_block + assert content_block["image"]["source"]["bytes"] == b"some_binary_data" + + +def test_convert_a2a_parts_to_content_blocks_data_part(): + """Test conversion of DataPart to ContentBlock.""" + from a2a.types import DataPart + + executor = StrandsA2AExecutor(MagicMock()) + + # Mock DataPart with proper spec + test_data = {"key": "value", "number": 42} + data_part = MagicMock(spec=DataPart) + data_part.data = test_data + + # Mock Part with DataPart root + part = MagicMock() + part.root = data_part + + result = executor._convert_a2a_parts_to_content_blocks([part]) + + assert len(result) == 1 + content_block = result[0] + assert "text" in content_block + assert "[Structured Data]" in content_block["text"] + assert "key" in content_block["text"] + assert "value" in content_block["text"] + + +def test_convert_a2a_parts_to_content_blocks_mixed_parts(): + """Test conversion of mixed A2A parts to ContentBlocks.""" + from a2a.types import DataPart, TextPart + + executor = StrandsA2AExecutor(MagicMock()) + + # Mock TextPart with proper spec + text_part = MagicMock(spec=TextPart) + text_part.text = "Text content" + text_part_mock = MagicMock() + text_part_mock.root = text_part + + # Mock DataPart with proper spec + data_part = MagicMock(spec=DataPart) + data_part.data = {"test": "data"} + data_part_mock = MagicMock() + data_part_mock.root = data_part + + parts = [text_part_mock, data_part_mock] + result = executor._convert_a2a_parts_to_content_blocks(parts) + + assert len(result) == 2 + assert result[0]["text"] == "Text content" + assert "[Structured Data]" in result[1]["text"] + + @pytest.mark.asyncio async def test_execute_streaming_mode_with_data_events(mock_strands_agent, mock_request_context, mock_event_queue): """Test that execute processes data events correctly in streaming mode.""" - async def mock_stream(user_input): + async def mock_stream(content_blocks): """Mock streaming function that yields data events.""" yield {"data": "First chunk"} yield {"data": "Second chunk"} yield {"result": MagicMock(spec=SAAgentResult)} # Setup mock agent streaming - mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input")) + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) # Create executor executor = 
StrandsA2AExecutor(mock_strands_agent) @@ -39,10 +326,25 @@ async def mock_stream(user_input): mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task + # Mock message with parts + from a2a.types import TextPart + + mock_message = MagicMock() + text_part = MagicMock(spec=TextPart) + text_part.text = "Test input" + part = MagicMock() + part.root = text_part + mock_message.parts = [part] + mock_request_context.message = mock_message + await executor.execute(mock_request_context, mock_event_queue) - # Verify agent was called with correct input - mock_strands_agent.stream_async.assert_called_once_with("Test input") + # Verify agent was called with ContentBlock list + mock_strands_agent.stream_async.assert_called_once() + call_args = mock_strands_agent.stream_async.call_args[0][0] + assert isinstance(call_args, list) + assert len(call_args) == 1 + assert call_args[0]["text"] == "Test input" # Verify events were enqueued mock_event_queue.enqueue_event.assert_called() @@ -52,12 +354,12 @@ async def mock_stream(user_input): async def test_execute_streaming_mode_with_result_event(mock_strands_agent, mock_request_context, mock_event_queue): """Test that execute processes result events correctly in streaming mode.""" - async def mock_stream(user_input): + async def mock_stream(content_blocks): """Mock streaming function that yields only result event.""" yield {"result": MagicMock(spec=SAAgentResult)} # Setup mock agent streaming - mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input")) + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) # Create executor executor = StrandsA2AExecutor(mock_strands_agent) @@ -68,10 +370,25 @@ async def mock_stream(user_input): mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task + # Mock message with parts + from a2a.types import TextPart + + mock_message = MagicMock() + text_part = MagicMock(spec=TextPart) + text_part.text = "Test input" + part = MagicMock() + part.root = text_part + mock_message.parts = [part] + mock_request_context.message = mock_message + await executor.execute(mock_request_context, mock_event_queue) - # Verify agent was called with correct input - mock_strands_agent.stream_async.assert_called_once_with("Test input") + # Verify agent was called with ContentBlock list + mock_strands_agent.stream_async.assert_called_once() + call_args = mock_strands_agent.stream_async.call_args[0][0] + assert isinstance(call_args, list) + assert len(call_args) == 1 + assert call_args[0]["text"] == "Test input" # Verify events were enqueued mock_event_queue.enqueue_event.assert_called() @@ -81,13 +398,13 @@ async def mock_stream(user_input): async def test_execute_streaming_mode_with_empty_data(mock_strands_agent, mock_request_context, mock_event_queue): """Test that execute handles empty data events correctly in streaming mode.""" - async def mock_stream(user_input): + async def mock_stream(content_blocks): """Mock streaming function that yields empty data.""" yield {"data": ""} yield {"result": MagicMock(spec=SAAgentResult)} # Setup mock agent streaming - mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input")) + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) # Create executor executor = StrandsA2AExecutor(mock_strands_agent) @@ -98,10 +415,25 @@ async def mock_stream(user_input): mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task + # Mock message with parts + 
from a2a.types import TextPart + + mock_message = MagicMock() + text_part = MagicMock(spec=TextPart) + text_part.text = "Test input" + part = MagicMock() + part.root = text_part + mock_message.parts = [part] + mock_request_context.message = mock_message + await executor.execute(mock_request_context, mock_event_queue) - # Verify agent was called with correct input - mock_strands_agent.stream_async.assert_called_once_with("Test input") + # Verify agent was called with ContentBlock list + mock_strands_agent.stream_async.assert_called_once() + call_args = mock_strands_agent.stream_async.call_args[0][0] + assert isinstance(call_args, list) + assert len(call_args) == 1 + assert call_args[0]["text"] == "Test input" # Verify events were enqueued mock_event_queue.enqueue_event.assert_called() @@ -111,13 +443,13 @@ async def mock_stream(user_input): async def test_execute_streaming_mode_with_unexpected_event(mock_strands_agent, mock_request_context, mock_event_queue): """Test that execute handles unexpected events correctly in streaming mode.""" - async def mock_stream(user_input): + async def mock_stream(content_blocks): """Mock streaming function that yields unexpected event.""" yield {"unexpected": "event"} yield {"result": MagicMock(spec=SAAgentResult)} # Setup mock agent streaming - mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input")) + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) # Create executor executor = StrandsA2AExecutor(mock_strands_agent) @@ -128,26 +460,69 @@ async def mock_stream(user_input): mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task + # Mock message with parts + from a2a.types import TextPart + + mock_message = MagicMock() + text_part = MagicMock(spec=TextPart) + text_part.text = "Test input" + part = MagicMock() + part.root = text_part + mock_message.parts = [part] + mock_request_context.message = mock_message + await executor.execute(mock_request_context, mock_event_queue) - # Verify agent was called with correct input - mock_strands_agent.stream_async.assert_called_once_with("Test input") + # Verify agent was called with ContentBlock list + mock_strands_agent.stream_async.assert_called_once() + call_args = mock_strands_agent.stream_async.call_args[0][0] + assert isinstance(call_args, list) + assert len(call_args) == 1 + assert call_args[0]["text"] == "Test input" # Verify events were enqueued mock_event_queue.enqueue_event.assert_called() +@pytest.mark.asyncio +async def test_execute_streaming_mode_fallback_to_text_extraction( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that execute raises ServerError when no A2A parts are available.""" + + # Create executor + executor = StrandsA2AExecutor(mock_strands_agent) + + # Mock the task creation + mock_task = MagicMock() + mock_task.id = "test-task-id" + mock_task.context_id = "test-context-id" + mock_request_context.current_task = mock_task + + # Mock message without parts attribute + mock_message = MagicMock() + delattr(mock_message, "parts") # Remove parts attribute + mock_request_context.message = mock_message + mock_request_context.get_user_input.return_value = "Fallback input" + + with pytest.raises(ServerError) as excinfo: + await executor.execute(mock_request_context, mock_event_queue) + + # Verify the error is a ServerError containing an InternalError + assert isinstance(excinfo.value.error, InternalError) + + @pytest.mark.asyncio async def 
test_execute_creates_task_when_none_exists(mock_strands_agent, mock_request_context, mock_event_queue): """Test that execute creates a new task when none exists.""" - async def mock_stream(user_input): + async def mock_stream(content_blocks): """Mock streaming function that yields data events.""" yield {"data": "Test chunk"} yield {"result": MagicMock(spec=SAAgentResult)} # Setup mock agent streaming - mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input")) + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) # Create executor executor = StrandsA2AExecutor(mock_strands_agent) @@ -155,6 +530,17 @@ async def mock_stream(user_input): # Mock no existing task mock_request_context.current_task = None + # Mock message with parts + from a2a.types import TextPart + + mock_message = MagicMock() + text_part = MagicMock(spec=TextPart) + text_part.text = "Test input" + part = MagicMock() + part.root = text_part + mock_message.parts = [part] + mock_request_context.message = mock_message + with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task: mock_new_task.return_value = MagicMock(id="new-task-id", context_id="new-context-id") @@ -183,11 +569,22 @@ async def test_execute_streaming_mode_handles_agent_exception( mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task + # Mock message with parts + from a2a.types import TextPart + + mock_message = MagicMock() + text_part = MagicMock(spec=TextPart) + text_part.text = "Test input" + part = MagicMock() + part.root = text_part + mock_message.parts = [part] + mock_request_context.message = mock_message + with pytest.raises(ServerError): await executor.execute(mock_request_context, mock_event_queue) # Verify agent was called - mock_strands_agent.stream_async.assert_called_once_with("Test input") + mock_strands_agent.stream_async.assert_called_once() @pytest.mark.asyncio @@ -252,3 +649,353 @@ async def test_handle_agent_result_with_result_but_no_message( # Verify completion was called mock_updater.complete.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_agent_result_with_content(mock_strands_agent): + """Test that _handle_agent_result handles result with content correctly.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + # Mock TaskUpdater + mock_updater = MagicMock() + mock_updater.complete = AsyncMock() + mock_updater.add_artifact = AsyncMock() + + # Create result with content + mock_result = MagicMock(spec=SAAgentResult) + mock_result.__str__ = MagicMock(return_value="Test response content") + + # Call _handle_agent_result + await executor._handle_agent_result(mock_result, mock_updater) + + # Verify artifact was added and task completed + mock_updater.add_artifact.assert_called_once() + mock_updater.complete.assert_called_once() + + # Check that the artifact contains the expected content + call_args = mock_updater.add_artifact.call_args[0][0] + assert len(call_args) == 1 + assert call_args[0].root.text == "Test response content" + + +def test_handle_conversion_error(): + """Test that conversion handles errors gracefully.""" + executor = StrandsA2AExecutor(MagicMock()) + + # Mock Part that will raise an exception during processing + problematic_part = MagicMock() + problematic_part.root = None # This should cause an AttributeError + + # Should not raise an exception, but return empty list or handle gracefully + result = executor._convert_a2a_parts_to_content_blocks([problematic_part]) + + # The method should handle the error and 
continue + assert isinstance(result, list) + + +def test_convert_a2a_parts_to_content_blocks_empty_list(): + """Test conversion with empty parts list.""" + executor = StrandsA2AExecutor(MagicMock()) + + result = executor._convert_a2a_parts_to_content_blocks([]) + + assert result == [] + + +def test_convert_a2a_parts_to_content_blocks_file_part_no_name(): + """Test conversion of FilePart with no file name.""" + from a2a.types import FilePart + + executor = StrandsA2AExecutor(MagicMock()) + + # Mock file object without name + file_obj = MagicMock() + delattr(file_obj, "name") # Remove name attribute + file_obj.mime_type = "text/plain" + file_obj.bytes = b"test content" + file_obj.uri = None + + # Mock FilePart with proper spec + file_part = MagicMock(spec=FilePart) + file_part.file = file_obj + + # Mock Part with FilePart root + part = MagicMock() + part.root = file_part + + result = executor._convert_a2a_parts_to_content_blocks([part]) + + assert len(result) == 1 + content_block = result[0] + assert "document" in content_block + assert content_block["document"]["name"] == "FileNameNotProvided" # Should use default + + +def test_convert_a2a_parts_to_content_blocks_file_part_no_mime_type(): + """Test conversion of FilePart with no MIME type.""" + from a2a.types import FilePart + + executor = StrandsA2AExecutor(MagicMock()) + + # Mock file object without MIME type + file_obj = MagicMock() + file_obj.name = "test_file" + delattr(file_obj, "mime_type") + file_obj.bytes = b"test content" + file_obj.uri = None + + # Mock FilePart with proper spec + file_part = MagicMock(spec=FilePart) + file_part.file = file_obj + + # Mock Part with FilePart root + part = MagicMock() + part.root = file_part + + result = executor._convert_a2a_parts_to_content_blocks([part]) + + assert len(result) == 1 + content_block = result[0] + assert "document" in content_block # Should default to document with unknown type + assert content_block["document"]["format"] == "txt" # Should use default format for unknown file type + + +def test_convert_a2a_parts_to_content_blocks_file_part_no_bytes_no_uri(): + """Test conversion of FilePart with neither bytes nor URI.""" + from a2a.types import FilePart + + executor = StrandsA2AExecutor(MagicMock()) + + # Mock file object without bytes or URI + file_obj = MagicMock() + file_obj.name = "test_file.txt" + file_obj.mime_type = "text/plain" + file_obj.bytes = None + file_obj.uri = None + + # Mock FilePart with proper spec + file_part = MagicMock(spec=FilePart) + file_part.file = file_obj + + # Mock Part with FilePart root + part = MagicMock() + part.root = file_part + + result = executor._convert_a2a_parts_to_content_blocks([part]) + + # Should return empty list since no fallback case exists + assert len(result) == 0 + + +def test_convert_a2a_parts_to_content_blocks_data_part_serialization_error(): + """Test conversion of DataPart with non-serializable data.""" + from a2a.types import DataPart + + executor = StrandsA2AExecutor(MagicMock()) + + # Create non-serializable data (e.g., a function) + def non_serializable(): + pass + + # Mock DataPart with proper spec + data_part = MagicMock(spec=DataPart) + data_part.data = {"function": non_serializable} # This will cause JSON serialization to fail + + # Mock Part with DataPart root + part = MagicMock() + part.root = data_part + + # Should not raise an exception, should handle gracefully + result = executor._convert_a2a_parts_to_content_blocks([part]) + + # The error handling should result in an empty list or the part being skipped + assert 
isinstance(result, list) + + +@pytest.mark.asyncio +async def test_execute_streaming_mode_raises_error_for_empty_content_blocks( + mock_strands_agent, mock_event_queue, mock_request_context +): + """Test that execute raises ServerError when content blocks are empty after conversion.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + # Create a mock message with parts that will result in empty content blocks + # This could happen if all parts fail to convert or are invalid + mock_message = MagicMock() + mock_message.parts = [MagicMock()] # Has parts but they won't convert to valid content blocks + mock_request_context.message = mock_message + + # Mock the conversion to return empty list + with patch.object(executor, "_convert_a2a_parts_to_content_blocks", return_value=[]): + with pytest.raises(ServerError) as excinfo: + await executor.execute(mock_request_context, mock_event_queue) + + # Verify the error is a ServerError containing an InternalError + assert isinstance(excinfo.value.error, InternalError) + + +@pytest.mark.asyncio +async def test_execute_with_mixed_part_types(mock_strands_agent, mock_request_context, mock_event_queue): + """Test execute with a message containing mixed A2A part types.""" + from a2a.types import DataPart, FilePart, TextPart + + async def mock_stream(content_blocks): + """Mock streaming function.""" + yield {"data": "Processing mixed content"} + yield {"result": MagicMock(spec=SAAgentResult)} + + # Setup mock agent streaming + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) + + # Create executor + executor = StrandsA2AExecutor(mock_strands_agent) + + # Mock the task creation + mock_task = MagicMock() + mock_task.id = "test-task-id" + mock_task.context_id = "test-context-id" + mock_request_context.current_task = mock_task + + # Create mixed parts + text_part = MagicMock(spec=TextPart) + text_part.text = "Hello" + text_part_mock = MagicMock() + text_part_mock.root = text_part + + # File part with bytes + file_obj = MagicMock() + file_obj.name = "image.png" + file_obj.mime_type = "image/png" + file_obj.bytes = b"fake_image" + file_obj.uri = None + file_part = MagicMock(spec=FilePart) + file_part.file = file_obj + file_part_mock = MagicMock() + file_part_mock.root = file_part + + # Data part + data_part = MagicMock(spec=DataPart) + data_part.data = {"key": "value"} + data_part_mock = MagicMock() + data_part_mock.root = data_part + + # Mock message with mixed parts + mock_message = MagicMock() + mock_message.parts = [text_part_mock, file_part_mock, data_part_mock] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Verify agent was called with ContentBlock list containing all types + mock_strands_agent.stream_async.assert_called_once() + call_args = mock_strands_agent.stream_async.call_args[0][0] + assert isinstance(call_args, list) + assert len(call_args) == 3 # Should have converted all 3 parts + + # Check that we have text, image, and structured data + has_text = any("text" in block for block in call_args) + has_image = any("image" in block for block in call_args) + has_structured_data = any("text" in block and "[Structured Data]" in block.get("text", "") for block in call_args) + + assert has_text + assert has_image + assert has_structured_data + + +def test_integration_example(): + """Integration test example showing how A2A Parts are converted to ContentBlocks. + + This test serves as documentation for the conversion functionality. 
+ """ + from a2a.types import DataPart, FilePart, TextPart + + executor = StrandsA2AExecutor(MagicMock()) + + # Example 1: Text content + text_part = MagicMock(spec=TextPart) + text_part.text = "Hello, this is a text message" + text_part_mock = MagicMock() + text_part_mock.root = text_part + + # Example 2: Image file + image_bytes = b"fake_image_content" + image_file = MagicMock() + image_file.name = "photo.jpg" + image_file.mime_type = "image/jpeg" + image_file.bytes = image_bytes + image_file.uri = None + + image_part = MagicMock(spec=FilePart) + image_part.file = image_file + image_part_mock = MagicMock() + image_part_mock.root = image_part + + # Example 3: Document file + doc_bytes = b"PDF document content" + doc_file = MagicMock() + doc_file.name = "report.pdf" + doc_file.mime_type = "application/pdf" + doc_file.bytes = doc_bytes + doc_file.uri = None + + doc_part = MagicMock(spec=FilePart) + doc_part.file = doc_file + doc_part_mock = MagicMock() + doc_part_mock.root = doc_part + + # Example 4: Structured data + data_part = MagicMock(spec=DataPart) + data_part.data = {"user": "john_doe", "action": "upload_file", "timestamp": "2023-12-01T10:00:00Z"} + data_part_mock = MagicMock() + data_part_mock.root = data_part + + # Convert all parts to ContentBlocks + parts = [text_part_mock, image_part_mock, doc_part_mock, data_part_mock] + content_blocks = executor._convert_a2a_parts_to_content_blocks(parts) + + # Verify conversion results + assert len(content_blocks) == 4 + + # Text part becomes text ContentBlock + assert content_blocks[0]["text"] == "Hello, this is a text message" + + # Image part becomes image ContentBlock with proper format and bytes + assert "image" in content_blocks[1] + assert content_blocks[1]["image"]["format"] == "jpeg" + assert content_blocks[1]["image"]["source"]["bytes"] == image_bytes + + # Document part becomes document ContentBlock + assert "document" in content_blocks[2] + assert content_blocks[2]["document"]["format"] == "pdf" + assert content_blocks[2]["document"]["name"] == "report" # Extension stripped + assert content_blocks[2]["document"]["source"]["bytes"] == doc_bytes + + # Data part becomes text ContentBlock with JSON representation + assert "text" in content_blocks[3] + assert "[Structured Data]" in content_blocks[3]["text"] + assert "john_doe" in content_blocks[3]["text"] + assert "upload_file" in content_blocks[3]["text"] + + +def test_default_formats_modularization(): + """Test that DEFAULT_FORMATS mapping works correctly for modular format defaults.""" + executor = StrandsA2AExecutor(MagicMock()) + + # Test that DEFAULT_FORMATS contains expected mappings + assert hasattr(executor, "DEFAULT_FORMATS") + assert executor.DEFAULT_FORMATS["document"] == "txt" + assert executor.DEFAULT_FORMATS["image"] == "png" + assert executor.DEFAULT_FORMATS["video"] == "mp4" + assert executor.DEFAULT_FORMATS["unknown"] == "txt" + + # Test format selection with None mime_type + assert executor._get_file_format_from_mime_type(None, "document") == "txt" + assert executor._get_file_format_from_mime_type(None, "image") == "png" + assert executor._get_file_format_from_mime_type(None, "video") == "mp4" + assert executor._get_file_format_from_mime_type(None, "unknown") == "txt" + assert executor._get_file_format_from_mime_type(None, "nonexistent") == "txt" # fallback + + # Test format selection with empty mime_type + assert executor._get_file_format_from_mime_type("", "document") == "txt" + assert executor._get_file_format_from_mime_type("", "image") == "png" + assert 
executor._get_file_format_from_mime_type("", "video") == "mp4" From 60dcb454c550002379444c698867b3f5e49fd490 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:17:59 -0400 Subject: [PATCH 043/104] ci: update pre-commit requirement from <4.2.0,>=3.2.0 to >=3.2.0,<4.4.0 (#706) Updates the requirements on [pre-commit](https://github.com/pre-commit/pre-commit) to permit the latest version. - [Release notes](https://github.com/pre-commit/pre-commit/releases) - [Changelog](https://github.com/pre-commit/pre-commit/blob/main/CHANGELOG.md) - [Commits](https://github.com/pre-commit/pre-commit/compare/v3.2.0...v4.3.0) --- updated-dependencies: - dependency-name: pre-commit dependency-version: 4.3.0 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 847db8d2b..de28c311c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ dev = [ "hatch>=1.0.0,<2.0.0", "moto>=5.1.0,<6.0.0", "mypy>=1.15.0,<2.0.0", - "pre-commit>=3.2.0,<4.2.0", + "pre-commit>=3.2.0,<4.4.0", "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "pytest-cov>=4.1.0,<5.0.0", From b61a06416b250693f162cd490b941643cdbefbc5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 21 Aug 2025 09:39:50 -0400 Subject: [PATCH 044/104] ci: update ruff requirement from <0.5.0,>=0.4.4 to >=0.4.4,<0.13.0 (#704) * ci: update ruff requirement from <0.5.0,>=0.4.4 to >=0.4.4,<0.13.0 Updates the requirements on [ruff](https://github.com/astral-sh/ruff) to permit the latest version. - [Release notes](https://github.com/astral-sh/ruff/releases) - [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md) - [Commits](https://github.com/astral-sh/ruff/compare/v0.4.4...0.12.9) --- updated-dependencies: - dependency-name: ruff dependency-version: 0.12.9 dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Apply suggestions from code review Co-authored-by: Patrick Gray --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jonathan Segev Co-authored-by: Patrick Gray --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index de28c311c..124ba5653 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ dev = [ "pytest-asyncio>=0.26.0,<0.27.0", "pytest-cov>=4.1.0,<5.0.0", "pytest-xdist>=3.0.0,<4.0.0", - "ruff>=0.4.4,<0.5.0", + "ruff>=0.12.0,<0.13.0", ] docs = [ "sphinx>=5.0.0,<6.0.0", From 93d3ac83573d6085e02b165541b55c8da3d10bce Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 21 Aug 2025 09:40:01 -0400 Subject: [PATCH 045/104] ci: update pytest-asyncio requirement from <0.27.0,>=0.26.0 to >=0.26.0,<1.2.0 (#708) * ci: update pytest-asyncio requirement Updates the requirements on [pytest-asyncio](https://github.com/pytest-dev/pytest-asyncio) to permit the latest version. - [Release notes](https://github.com/pytest-dev/pytest-asyncio/releases) - [Commits](https://github.com/pytest-dev/pytest-asyncio/compare/v0.26.0...v1.1.0) --- updated-dependencies: - dependency-name: pytest-asyncio dependency-version: 1.1.0 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] * Apply suggestions from code review Co-authored-by: Patrick Gray --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jonathan Segev Co-authored-by: Patrick Gray --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 124ba5653..f91454414 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dev = [ "mypy>=1.15.0,<2.0.0", "pre-commit>=3.2.0,<4.4.0", "pytest>=8.0.0,<9.0.0", - "pytest-asyncio>=0.26.0,<0.27.0", + "pytest-asyncio>=1.0.0,<1.2.0", "pytest-cov>=4.1.0,<5.0.0", "pytest-xdist>=3.0.0,<4.0.0", "ruff>=0.12.0,<0.13.0", @@ -143,7 +143,7 @@ features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mis extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", - "pytest-asyncio>=0.26.0,<0.27.0", + "pytest-asyncio>=1.0.0,<1.2.0", "pytest-cov>=4.1.0,<5.0.0", "pytest-xdist>=3.0.0,<4.0.0", ] From 9397f58a953b83a7190e686ac6e29fa6d4e8ac86 Mon Sep 17 00:00:00 2001 From: Xwei Date: Thu, 21 Aug 2025 22:19:13 +0800 Subject: [PATCH 046/104] fix: add system_prompt to structured_output_span before adding input_messages (#709) * fix: add system_prompt to structured_output_span before adding input_messages * test: Add system message ordering validation to agent structured output test * Switch to ensuring exact ordering of messages --------- Co-authored-by: Dennis Tsai (RD-AS) Co-authored-by: Mackenzie Zastrow --- src/strands/agent/agent.py | 10 +++++----- tests/strands/agent/test_agent.py | 25 +++++++++++++++++-------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index acc6a7650..5150060c6 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -470,16 +470,16 @@ async def structured_output_async( "gen_ai.operation.name": "execute_structured_output", } ) - for message in temp_messages: - structured_output_span.add_event( - f"gen_ai.{message['role']}.message", - attributes={"role": message["role"], "content": serialize(message["content"])}, - ) if self.system_prompt: structured_output_span.add_event( "gen_ai.system.message", attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])}, ) + for message in temp_messages: + structured_output_span.add_event( + f"gen_ai.{message['role']}.message", + attributes={"role": message["role"], "content": serialize(message["content"])}, + ) events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) async for event in events: if "callback" in event: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 444232455..7e769c6d7 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -18,6 +18,7 @@ from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager +from strands.telemetry.tracer import serialize from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType @@ -1028,15 +1029,23 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator): } ) - 
mock_span.add_event.assert_any_call( - "gen_ai.user.message", - attributes={"role": "user", "content": '[{"text": "Jane Doe is 30 years old and her email is jane@doe.com"}]'}, - ) + # ensure correct otel event messages are emitted + act_event_names = mock_span.add_event.call_args_list + exp_event_names = [ + unittest.mock.call( + "gen_ai.system.message", attributes={"role": "system", "content": serialize([{"text": system_prompt}])} + ), + unittest.mock.call( + "gen_ai.user.message", + attributes={ + "role": "user", + "content": '[{"text": "Jane Doe is 30 years old and her email is jane@doe.com"}]', + }, + ), + unittest.mock.call("gen_ai.choice", attributes={"message": json.dumps(user.model_dump())}), + ] - mock_span.add_event.assert_called_with( - "gen_ai.choice", - attributes={"message": json.dumps(user.model_dump())}, - ) + assert act_event_names == exp_event_names def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, agenerator): From 6ef64478d7fde3c677ea13cadf068422a3d01377 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 25 Aug 2025 16:16:50 +0300 Subject: [PATCH 047/104] feat(multiagent): Add __call__ implementation to MultiAgentBase (#645) --- src/strands/multiagent/base.py | 11 +++++++-- tests/strands/multiagent/test_base.py | 34 +++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index c6b1af702..69578cb5d 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -3,7 +3,9 @@ Provides minimal foundation for multi-agent patterns (Swarm, Graph). """ +import asyncio from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum from typing import Any, Union @@ -86,7 +88,12 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> M """Invoke asynchronously.""" raise NotImplementedError("invoke_async not implemented") - @abstractmethod def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult: """Invoke synchronously.""" - raise NotImplementedError("__call__ not implemented") + + def execute() -> MultiAgentResult: + return asyncio.run(self.invoke_async(task, **kwargs)) + + with ThreadPoolExecutor() as executor: + future = executor.submit(execute) + return future.result() diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 7aa76bb90..395d9275c 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -141,9 +141,35 @@ class CompleteMultiAgent(MultiAgentBase): async def invoke_async(self, task: str) -> MultiAgentResult: return MultiAgentResult(results={}) - def __call__(self, task: str) -> MultiAgentResult: - return MultiAgentResult(results={}) - - # Should not raise an exception + # Should not raise an exception - __call__ is provided by base class agent = CompleteMultiAgent() assert isinstance(agent, MultiAgentBase) + + +def test_multi_agent_base_call_method(): + """Test that __call__ method properly delegates to invoke_async.""" + + class TestMultiAgent(MultiAgentBase): + def __init__(self): + self.invoke_async_called = False + self.received_task = None + self.received_kwargs = None + + async def invoke_async(self, task, **kwargs): + self.invoke_async_called = True + self.received_task = task + self.received_kwargs = kwargs + return MultiAgentResult( + status=Status.COMPLETED, results={"test": 
NodeResult(result=Exception("test"), status=Status.COMPLETED)} + ) + + agent = TestMultiAgent() + + # Test with string task + result = agent("test task", param1="value1", param2="value2") + + assert agent.invoke_async_called + assert agent.received_task == "test task" + assert agent.received_kwargs == {"param1": "value1", "param2": "value2"} + assert isinstance(result, MultiAgentResult) + assert result.status == Status.COMPLETED From e4879e18121985b860d0f9e3556c0bf7e512a4a7 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Mon, 25 Aug 2025 06:40:26 -0700 Subject: [PATCH 048/104] chore: Update pydantic minimum version (#723) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f91454414..32de94aa6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "botocore>=1.29.0,<2.0.0", "docstring_parser>=0.15,<1.0", "mcp>=1.11.0,<2.0.0", - "pydantic>=2.0.0,<3.0.0", + "pydantic>=2.4.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", "opentelemetry-api>=1.30.0,<2.0.0", From c18ef930ee7c436f7af58845001a2e02014b52da Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 25 Aug 2025 10:16:04 -0400 Subject: [PATCH 049/104] tool executors (#658) --- src/strands/agent/agent.py | 15 +- src/strands/event_loop/event_loop.py | 155 +----- src/strands/tools/_validator.py | 45 ++ src/strands/tools/executor.py | 137 ------ src/strands/tools/executors/__init__.py | 16 + src/strands/tools/executors/_executor.py | 227 +++++++++ src/strands/tools/executors/concurrent.py | 113 +++++ src/strands/tools/executors/sequential.py | 46 ++ tests/strands/agent/test_agent.py | 44 +- tests/strands/event_loop/test_event_loop.py | 271 +---------- tests/strands/tools/executors/conftest.py | 116 +++++ .../tools/executors/test_concurrent.py | 32 ++ .../strands/tools/executors/test_executor.py | 144 ++++++ .../tools/executors/test_sequential.py | 32 ++ tests/strands/tools/test_executor.py | 440 ------------------ tests/strands/tools/test_validator.py | 50 ++ .../tools/executors/test_concurrent.py | 61 +++ .../tools/executors/test_sequential.py | 61 +++ 18 files changed, 985 insertions(+), 1020 deletions(-) create mode 100644 src/strands/tools/_validator.py delete mode 100644 src/strands/tools/executor.py create mode 100644 src/strands/tools/executors/__init__.py create mode 100644 src/strands/tools/executors/_executor.py create mode 100644 src/strands/tools/executors/concurrent.py create mode 100644 src/strands/tools/executors/sequential.py create mode 100644 tests/strands/tools/executors/conftest.py create mode 100644 tests/strands/tools/executors/test_concurrent.py create mode 100644 tests/strands/tools/executors/test_executor.py create mode 100644 tests/strands/tools/executors/test_sequential.py delete mode 100644 tests/strands/tools/test_executor.py create mode 100644 tests/strands/tools/test_validator.py create mode 100644 tests_integ/tools/executors/test_concurrent.py create mode 100644 tests_integ/tools/executors/test_sequential.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 5150060c6..adc554bf4 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -20,7 +20,7 @@ from pydantic import BaseModel from .. 
 import _identifier -from ..event_loop.event_loop import event_loop_cycle, run_tool +from ..event_loop.event_loop import event_loop_cycle from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -35,6 +35,8 @@ from ..session.session_manager import SessionManager from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer, serialize +from ..tools.executors import ConcurrentToolExecutor +from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages @@ -136,13 +138,14 @@ def caller( "name": normalized_name, "input": kwargs.copy(), } + tool_results: list[ToolResult] = [] + invocation_state = kwargs async def acall() -> ToolResult: - # Pass kwargs as invocation_state - async for event in run_tool(self._agent, tool_use, kwargs): + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): _ = event - return cast(ToolResult, event) + return tool_results[0] def tcall() -> ToolResult: return asyncio.run(acall()) @@ -208,6 +211,7 @@ def __init__( state: Optional[Union[AgentState, dict]] = None, hooks: Optional[list[HookProvider]] = None, session_manager: Optional[SessionManager] = None, + tool_executor: Optional[ToolExecutor] = None, ): """Initialize the Agent with the specified configuration. @@ -250,6 +254,7 @@ def __init__( Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. + tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). Raises: ValueError: If agent id contains path separators. 
@@ -324,6 +329,8 @@ def __init__( if self._session_manager: self.hooks.add_hook(self._session_manager) + self.tool_executor = tool_executor or ConcurrentToolExecutor() + if hooks: for hook in hooks: self.hooks.add_hook(hook) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index b36f73155..524ecc3e8 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -11,22 +11,20 @@ import logging import time import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator from opentelemetry import trace as trace_api from ..experimental.hooks import ( AfterModelInvocationEvent, - AfterToolInvocationEvent, BeforeModelInvocationEvent, - BeforeToolInvocationEvent, ) from ..hooks import ( MessageAddedEvent, ) from ..telemetry.metrics import Trace from ..telemetry.tracer import get_tracer -from ..tools.executor import run_tools, validate_and_prepare_tools +from ..tools._validator import validate_and_prepare_tools from ..types.content import Message from ..types.exceptions import ( ContextWindowOverflowException, @@ -35,7 +33,7 @@ ModelThrottledException, ) from ..types.streaming import Metrics, StopReason -from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse +from ..types.tools import ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached from .streaming import stream_messages @@ -212,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> if stop_reason == "max_tokens": """ Handle max_tokens limit reached by the model. - + When the model reaches its maximum token limit, this represents a potentially unrecoverable state where the model's response was truncated. By default, Strands fails hard with an MaxTokensReachedException to maintain consistency with other failure types. @@ -306,122 +304,6 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - recursive_trace.end() -async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator: - """Process a tool invocation. - - Looks up the tool in the registry and streams it with the provided parameters. - - Args: - agent: The agent for which the tool is being executed. - tool_use: The tool object to process, containing name and parameters. - invocation_state: Context for the tool invocation, including agent state. - - Yields: - Tool events with the last being the tool result. 
- """ - logger.debug("tool_use=<%s> | streaming", tool_use) - tool_name = tool_use["name"] - - # Get the tool info - tool_info = agent.tool_registry.dynamic_tools.get(tool_name) - tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) - - # Add standard arguments to invocation_state for Python tools - invocation_state.update( - { - "model": agent.model, - "system_prompt": agent.system_prompt, - "messages": agent.messages, - "tool_config": ToolConfig( # for backwards compatability - tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()], - toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}), - ), - } - ) - - before_event = agent.hooks.invoke_callbacks( - BeforeToolInvocationEvent( - agent=agent, - selected_tool=tool_func, - tool_use=tool_use, - invocation_state=invocation_state, - ) - ) - - try: - selected_tool = before_event.selected_tool - tool_use = before_event.tool_use - invocation_state = before_event.invocation_state # Get potentially modified invocation_state from hook - - # Check if tool exists - if not selected_tool: - if tool_func == selected_tool: - logger.error( - "tool_name=<%s>, available_tools=<%s> | tool not found in registry", - tool_name, - list(agent.tool_registry.registry.keys()), - ) - else: - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", - tool_name, - str(tool_use.get("toolUseId")), - ) - - result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": f"Unknown tool: {tool_name}"}], - } - # for every Before event call, we need to have an AfterEvent call - after_event = agent.hooks.invoke_callbacks( - AfterToolInvocationEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks - result=result, - ) - ) - yield after_event.result - return - - async for event in selected_tool.stream(tool_use, invocation_state): - yield event - - result = event - - after_event = agent.hooks.invoke_callbacks( - AfterToolInvocationEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks - result=result, - ) - ) - yield after_event.result - - except Exception as e: - logger.exception("tool_name=<%s> | failed to process tool", tool_name) - error_result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": f"Error: {str(e)}"}], - } - after_event = agent.hooks.invoke_callbacks( - AfterToolInvocationEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks - result=error_result, - exception=e, - ) - ) - yield after_event.result - - async def _handle_tool_execution( stop_reason: StopReason, message: Message, @@ -431,18 +313,12 @@ async def _handle_tool_execution( cycle_start_time: float, invocation_state: dict[str, Any], ) -> AsyncGenerator[dict[str, Any], None]: - tool_uses: list[ToolUse] = [] - tool_results: list[ToolResult] = [] - invalid_tool_use_ids: list[str] = [] - - """ - Handles the execution of tools requested by the model during an event loop cycle. + """Handles the execution of tools requested by the model during an event loop cycle. Args: stop_reason: The reason the model stopped generating. 
message: The message from the model that may contain tool use requests. - event_loop_metrics: Metrics tracking object for the event loop. - event_loop_parent_span: Span for the parent of this event loop. + agent: Agent for which tools are being executed. cycle_trace: Trace object for the current event loop cycle. cycle_span: Span object for tracing the cycle (type may vary). cycle_start_time: Start time of the current cycle. @@ -456,23 +332,18 @@ async def _handle_tool_execution( - The updated event loop metrics, - The updated request state. """ - validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) + tool_uses: list[ToolUse] = [] + tool_results: list[ToolResult] = [] + invalid_tool_use_ids: list[str] = [] + validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) + tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] if not tool_uses: yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} return - def tool_handler(tool_use: ToolUse) -> ToolGenerator: - return run_tool(agent, tool_use, invocation_state) - - tool_events = run_tools( - handler=tool_handler, - tool_uses=tool_uses, - event_loop_metrics=agent.event_loop_metrics, - invalid_tool_use_ids=invalid_tool_use_ids, - tool_results=tool_results, - cycle_trace=cycle_trace, - parent_span=cycle_span, + tool_events = agent.tool_executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state ) async for tool_event in tool_events: yield tool_event diff --git a/src/strands/tools/_validator.py b/src/strands/tools/_validator.py new file mode 100644 index 000000000..77aa57e87 --- /dev/null +++ b/src/strands/tools/_validator.py @@ -0,0 +1,45 @@ +"""Tool validation utilities.""" + +from ..tools.tools import InvalidToolUseNameException, validate_tool_use +from ..types.content import Message +from ..types.tools import ToolResult, ToolUse + + +def validate_and_prepare_tools( + message: Message, + tool_uses: list[ToolUse], + tool_results: list[ToolResult], + invalid_tool_use_ids: list[str], +) -> None: + """Validate tool uses and prepare them for execution. + + Args: + message: Current message. + tool_uses: List to populate with tool uses. + tool_results: List to populate with tool results for invalid tools. + invalid_tool_use_ids: List to populate with invalid tool use IDs. 
+ """ + # Extract tool uses from message + for content in message["content"]: + if isinstance(content, dict) and "toolUse" in content: + tool_uses.append(content["toolUse"]) + + # Validate tool uses + # Avoid modifying original `tool_uses` variable during iteration + tool_uses_copy = tool_uses.copy() + for tool in tool_uses_copy: + try: + validate_tool_use(tool) + except InvalidToolUseNameException as e: + # Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context + tool_uses.remove(tool) + tool["name"] = "INVALID_TOOL_NAME" + invalid_tool_use_ids.append(tool["toolUseId"]) + tool_uses.append(tool) + tool_results.append( + { + "toolUseId": tool["toolUseId"], + "status": "error", + "content": [{"text": f"Error: {str(e)}"}], + } + ) diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py deleted file mode 100644 index d90f9a5aa..000000000 --- a/src/strands/tools/executor.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Tool execution functionality for the event loop.""" - -import asyncio -import logging -import time -from typing import Any, Optional, cast - -from opentelemetry import trace as trace_api - -from ..telemetry.metrics import EventLoopMetrics, Trace -from ..telemetry.tracer import get_tracer -from ..tools.tools import InvalidToolUseNameException, validate_tool_use -from ..types.content import Message -from ..types.tools import RunToolHandler, ToolGenerator, ToolResult, ToolUse - -logger = logging.getLogger(__name__) - - -async def run_tools( - handler: RunToolHandler, - tool_uses: list[ToolUse], - event_loop_metrics: EventLoopMetrics, - invalid_tool_use_ids: list[str], - tool_results: list[ToolResult], - cycle_trace: Trace, - parent_span: Optional[trace_api.Span] = None, -) -> ToolGenerator: - """Execute tools concurrently. - - Args: - handler: Tool handler processing function. - tool_uses: List of tool uses to execute. - event_loop_metrics: Metrics collection object. - invalid_tool_use_ids: List of invalid tool use IDs. - tool_results: List to populate with tool results. - cycle_trace: Parent trace for the current cycle. - parent_span: Parent span for the current cycle. - - Yields: - Events of the tool stream. Tool results are appended to `tool_results`. 
- """ - - async def work( - tool_use: ToolUse, - worker_id: int, - worker_queue: asyncio.Queue, - worker_event: asyncio.Event, - stop_event: object, - ) -> ToolResult: - tracer = get_tracer() - tool_call_span = tracer.start_tool_call_span(tool_use, parent_span) - - tool_name = tool_use["name"] - tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) - tool_start_time = time.time() - with trace_api.use_span(tool_call_span): - try: - async for event in handler(tool_use): - worker_queue.put_nowait((worker_id, event)) - await worker_event.wait() - worker_event.clear() - - result = cast(ToolResult, event) - finally: - worker_queue.put_nowait((worker_id, stop_event)) - - tool_success = result.get("status") == "success" - tool_duration = time.time() - tool_start_time - message = Message(role="user", content=[{"toolResult": result}]) - event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) - cycle_trace.add_child(tool_trace) - - tracer.end_tool_call_span(tool_call_span, result) - - return result - - tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] - worker_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue() - worker_events = [asyncio.Event() for _ in tool_uses] - stop_event = object() - - workers = [ - asyncio.create_task(work(tool_use, worker_id, worker_queue, worker_events[worker_id], stop_event)) - for worker_id, tool_use in enumerate(tool_uses) - ] - - worker_count = len(workers) - while worker_count: - worker_id, event = await worker_queue.get() - if event is stop_event: - worker_count -= 1 - continue - - yield event - worker_events[worker_id].set() - - tool_results.extend([worker.result() for worker in workers]) - - -def validate_and_prepare_tools( - message: Message, - tool_uses: list[ToolUse], - tool_results: list[ToolResult], - invalid_tool_use_ids: list[str], -) -> None: - """Validate tool uses and prepare them for execution. - - Args: - message: Current message. - tool_uses: List to populate with tool uses. - tool_results: List to populate with tool results for invalid tools. - invalid_tool_use_ids: List to populate with invalid tool use IDs. - """ - # Extract tool uses from message - for content in message["content"]: - if isinstance(content, dict) and "toolUse" in content: - tool_uses.append(content["toolUse"]) - - # Validate tool uses - # Avoid modifying original `tool_uses` variable during iteration - tool_uses_copy = tool_uses.copy() - for tool in tool_uses_copy: - try: - validate_tool_use(tool) - except InvalidToolUseNameException as e: - # Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context - tool_uses.remove(tool) - tool["name"] = "INVALID_TOOL_NAME" - invalid_tool_use_ids.append(tool["toolUseId"]) - tool_uses.append(tool) - tool_results.append( - { - "toolUseId": tool["toolUseId"], - "status": "error", - "content": [{"text": f"Error: {str(e)}"}], - } - ) diff --git a/src/strands/tools/executors/__init__.py b/src/strands/tools/executors/__init__.py new file mode 100644 index 000000000..c8be812e4 --- /dev/null +++ b/src/strands/tools/executors/__init__.py @@ -0,0 +1,16 @@ +"""Tool executors for the Strands SDK. + +This package provides different execution strategies for tools, allowing users to customize +how tools are executed (e.g., concurrent, sequential, with custom thread pools, etc.). +""" + +from . 
import concurrent, sequential +from .concurrent import ConcurrentToolExecutor +from .sequential import SequentialToolExecutor + +__all__ = [ + "ConcurrentToolExecutor", + "SequentialToolExecutor", + "concurrent", + "sequential", +] diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py new file mode 100644 index 000000000..9999b77fc --- /dev/null +++ b/src/strands/tools/executors/_executor.py @@ -0,0 +1,227 @@ +"""Abstract base class for tool executors. + +Tool executors are responsible for determining how tools are executed (e.g., concurrently, sequentially, with custom +thread pools, etc.). +""" + +import abc +import logging +import time +from typing import TYPE_CHECKING, Any, cast + +from opentelemetry import trace as trace_api + +from ...experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from ...telemetry.metrics import Trace +from ...telemetry.tracer import get_tracer +from ...types.content import Message +from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse + +if TYPE_CHECKING: # pragma: no cover + from ...agent import Agent + +logger = logging.getLogger(__name__) + + +class ToolExecutor(abc.ABC): + """Abstract base class for tool executors.""" + + @staticmethod + async def _stream( + agent: "Agent", + tool_use: ToolUse, + tool_results: list[ToolResult], + invocation_state: dict[str, Any], + **kwargs: Any, + ) -> ToolGenerator: + """Stream tool events. + + This method adds additional logic to the stream invocation including: + + - Tool lookup and validation + - Before/after hook execution + - Tracing and metrics collection + - Error handling and recovery + + Args: + agent: The agent for which the tool is being executed. + tool_use: Metadata and inputs for the tool to be executed. + tool_results: List of tool results from each tool execution. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Tool events with the last being the tool result. 
+ """ + logger.debug("tool_use=<%s> | streaming", tool_use) + tool_name = tool_use["name"] + + tool_info = agent.tool_registry.dynamic_tools.get(tool_name) + tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) + + invocation_state.update( + { + "model": agent.model, + "messages": agent.messages, + "system_prompt": agent.system_prompt, + "tool_config": ToolConfig( # for backwards compatibility + tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()], + toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}), + ), + } + ) + + before_event = agent.hooks.invoke_callbacks( + BeforeToolInvocationEvent( + agent=agent, + selected_tool=tool_func, + tool_use=tool_use, + invocation_state=invocation_state, + ) + ) + + try: + selected_tool = before_event.selected_tool + tool_use = before_event.tool_use + invocation_state = before_event.invocation_state + + if not selected_tool: + if tool_func == selected_tool: + logger.error( + "tool_name=<%s>, available_tools=<%s> | tool not found in registry", + tool_name, + list(agent.tool_registry.registry.keys()), + ) + else: + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", + tool_name, + str(tool_use.get("toolUseId")), + ) + + result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Unknown tool: {tool_name}"}], + } + after_event = agent.hooks.invoke_callbacks( + AfterToolInvocationEvent( + agent=agent, + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + ) + ) + yield after_event.result + tool_results.append(after_event.result) + return + + async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): + yield event + + result = cast(ToolResult, event) + + after_event = agent.hooks.invoke_callbacks( + AfterToolInvocationEvent( + agent=agent, + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + ) + ) + yield after_event.result + tool_results.append(after_event.result) + + except Exception as e: + logger.exception("tool_name=<%s> | failed to process tool", tool_name) + error_result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Error: {str(e)}"}], + } + after_event = agent.hooks.invoke_callbacks( + AfterToolInvocationEvent( + agent=agent, + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=error_result, + exception=e, + ) + ) + yield after_event.result + tool_results.append(after_event.result) + + @staticmethod + async def _stream_with_trace( + agent: "Agent", + tool_use: ToolUse, + tool_results: list[ToolResult], + cycle_trace: Trace, + cycle_span: Any, + invocation_state: dict[str, Any], + **kwargs: Any, + ) -> ToolGenerator: + """Execute tool with tracing and metrics collection. + + Args: + agent: The agent for which the tool is being executed. + tool_use: Metadata and inputs for the tool to be executed. + tool_results: List of tool results from each tool execution. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Tool events with the last being the tool result. 
+ """ + tool_name = tool_use["name"] + + tracer = get_tracer() + + tool_call_span = tracer.start_tool_call_span(tool_use, cycle_span) + tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) + tool_start_time = time.time() + + with trace_api.use_span(tool_call_span): + async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs): + yield event + + result = cast(ToolResult, event) + + tool_success = result.get("status") == "success" + tool_duration = time.time() - tool_start_time + message = Message(role="user", content=[{"toolResult": result}]) + agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) + cycle_trace.add_child(tool_trace) + + tracer.end_tool_call_span(tool_call_span, result) + + @abc.abstractmethod + # pragma: no cover + def _execute( + self, + agent: "Agent", + tool_uses: list[ToolUse], + tool_results: list[ToolResult], + cycle_trace: Trace, + cycle_span: Any, + invocation_state: dict[str, Any], + ) -> ToolGenerator: + """Execute the given tools according to this executor's strategy. + + Args: + agent: The agent for which tools are being executed. + tool_uses: Metadata and inputs for the tools to be executed. + tool_results: List of tool results from each tool execution. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle. + invocation_state: Context for the tool invocation. + + Yields: + Events from the tool execution stream. + """ + pass diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py new file mode 100644 index 000000000..7d5dd7fe7 --- /dev/null +++ b/src/strands/tools/executors/concurrent.py @@ -0,0 +1,113 @@ +"""Concurrent tool executor implementation.""" + +import asyncio +from typing import TYPE_CHECKING, Any + +from typing_extensions import override + +from ...telemetry.metrics import Trace +from ...types.tools import ToolGenerator, ToolResult, ToolUse +from ._executor import ToolExecutor + +if TYPE_CHECKING: # pragma: no cover + from ...agent import Agent + + +class ConcurrentToolExecutor(ToolExecutor): + """Concurrent tool executor.""" + + @override + async def _execute( + self, + agent: "Agent", + tool_uses: list[ToolUse], + tool_results: list[ToolResult], + cycle_trace: Trace, + cycle_span: Any, + invocation_state: dict[str, Any], + ) -> ToolGenerator: + """Execute tools concurrently. + + Args: + agent: The agent for which tools are being executed. + tool_uses: Metadata and inputs for the tools to be executed. + tool_results: List of tool results from each tool execution. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle. + invocation_state: Context for the tool invocation. + + Yields: + Events from the tool execution stream. 
+ """ + task_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue() + task_events = [asyncio.Event() for _ in tool_uses] + stop_event = object() + + tasks = [ + asyncio.create_task( + self._task( + agent, + tool_use, + tool_results, + cycle_trace, + cycle_span, + invocation_state, + task_id, + task_queue, + task_events[task_id], + stop_event, + ) + ) + for task_id, tool_use in enumerate(tool_uses) + ] + + task_count = len(tasks) + while task_count: + task_id, event = await task_queue.get() + if event is stop_event: + task_count -= 1 + continue + + yield event + task_events[task_id].set() + + asyncio.gather(*tasks) + + async def _task( + self, + agent: "Agent", + tool_use: ToolUse, + tool_results: list[ToolResult], + cycle_trace: Trace, + cycle_span: Any, + invocation_state: dict[str, Any], + task_id: int, + task_queue: asyncio.Queue, + task_event: asyncio.Event, + stop_event: object, + ) -> None: + """Execute a single tool and put results in the task queue. + + Args: + agent: The agent executing the tool. + tool_use: Tool use metadata and inputs. + tool_results: List of tool results from each tool execution. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle. + invocation_state: Context for tool execution. + task_id: Unique identifier for this task. + task_queue: Queue to put tool events into. + task_event: Event to signal when task can continue. + stop_event: Sentinel object to signal task completion. + """ + try: + events = ToolExecutor._stream_with_trace( + agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state + ) + async for event in events: + task_queue.put_nowait((task_id, event)) + await task_event.wait() + task_event.clear() + + finally: + task_queue.put_nowait((task_id, stop_event)) diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py new file mode 100644 index 000000000..55b26f6d3 --- /dev/null +++ b/src/strands/tools/executors/sequential.py @@ -0,0 +1,46 @@ +"""Sequential tool executor implementation.""" + +from typing import TYPE_CHECKING, Any + +from typing_extensions import override + +from ...telemetry.metrics import Trace +from ...types.tools import ToolGenerator, ToolResult, ToolUse +from ._executor import ToolExecutor + +if TYPE_CHECKING: # pragma: no cover + from ...agent import Agent + + +class SequentialToolExecutor(ToolExecutor): + """Sequential tool executor.""" + + @override + async def _execute( + self, + agent: "Agent", + tool_uses: list[ToolUse], + tool_results: list[ToolResult], + cycle_trace: Trace, + cycle_span: Any, + invocation_state: dict[str, Any], + ) -> ToolGenerator: + """Execute tools sequentially. + + Args: + agent: The agent for which tools are being executed. + tool_uses: Metadata and inputs for the tools to be executed. + tool_results: List of tool results from each tool execution. + cycle_trace: Trace object for the current event loop cycle. + cycle_span: Span object for tracing the cycle. + invocation_state: Context for the tool invocation. + + Yields: + Events from the tool execution stream. 
+ """ + for tool_use in tool_uses: + events = ToolExecutor._stream_with_trace( + agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state + ) + async for event in events: + yield event diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 7e769c6d7..279e2a06e 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -73,12 +73,6 @@ def mock_event_loop_cycle(): yield mock -@pytest.fixture -def mock_run_tool(): - with unittest.mock.patch("strands.agent.agent.run_tool") as mock: - yield mock - - @pytest.fixture def tool_registry(): return strands.tools.registry.ToolRegistry() @@ -888,9 +882,7 @@ def test_agent_init_with_no_model_or_model_id(): assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID -def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, mock_run_tool, agenerator): - mock_run_tool.return_value = agenerator([{}]) - +def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator): @strands.tools.tool(name="system_prompter") def function(system_prompt: str) -> str: return system_prompt @@ -899,22 +891,12 @@ def function(system_prompt: str) -> str: mock_randint.return_value = 1 - agent.tool.system_prompter(system_prompt="tool prompt") - - mock_run_tool.assert_called_with( - agent, - { - "toolUseId": "tooluse_system_prompter_1", - "name": "system_prompter", - "input": {"system_prompt": "tool prompt"}, - }, - {"system_prompt": "tool prompt"}, - ) - + tru_result = agent.tool.system_prompter(system_prompt="tool prompt") + exp_result = {"toolUseId": "tooluse_system_prompter_1", "status": "success", "content": [{"text": "tool prompt"}]} + assert tru_result == exp_result -def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, mock_run_tool, agenerator): - mock_run_tool.return_value = agenerator([{}]) +def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, agenerator): tool_name = "system-prompter" @strands.tools.tool(name=tool_name) @@ -925,19 +907,9 @@ def function(system_prompt: str) -> str: mock_randint.return_value = 1 - agent.tool.system_prompter(system_prompt="tool prompt") - - # Verify the correct tool was invoked - assert mock_run_tool.call_count == 1 - tru_tool_use = mock_run_tool.call_args.args[1] - exp_tool_use = { - # Note that the tool-use uses the "python safe" name - "toolUseId": "tooluse_system_prompter_1", - # But the name of the tool is the one in the registry - "name": tool_name, - "input": {"system_prompt": "tool prompt"}, - } - assert tru_tool_use == exp_tool_use + tru_result = agent.tool.system_prompter(system_prompt="tool prompt") + exp_result = {"toolUseId": "tooluse_system_prompter_1", "status": "success", "content": [{"text": "tool prompt"}]} + assert tru_result == exp_result def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint): diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 191ab51ba..c76514ac8 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,23 +1,20 @@ import concurrent import unittest.mock -from unittest.mock import ANY, MagicMock, call, patch +from unittest.mock import MagicMock, call, patch import pytest import strands import strands.telemetry -from strands.event_loop.event_loop import run_tool from strands.experimental.hooks import ( AfterModelInvocationEvent, AfterToolInvocationEvent, 
BeforeModelInvocationEvent, BeforeToolInvocationEvent, ) -from strands.hooks import ( - HookProvider, - HookRegistry, -) +from strands.hooks import HookRegistry from strands.telemetry.metrics import EventLoopMetrics +from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry from strands.types.exceptions import ( ContextWindowOverflowException, @@ -131,7 +128,12 @@ def hook_provider(hook_registry): @pytest.fixture -def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry): +def tool_executor(): + return SequentialToolExecutor() + + +@pytest.fixture +def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry, tool_executor): mock = unittest.mock.Mock(name="agent") mock.config.cache_points = [] mock.model = model @@ -141,6 +143,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.thread_pool = thread_pool mock.event_loop_metrics = EventLoopMetrics() mock.hooks = hook_registry + mock.tool_executor = tool_executor return mock @@ -812,260 +815,6 @@ async def test_prepare_next_cycle_in_tool_execution(agent, model, tool_stream, a ) -@pytest.mark.asyncio -async def test_run_tool(agent, tool, alist): - process = run_tool( - agent, - tool_use={"toolUseId": "tool_use_id", "name": tool.tool_name, "input": {"random_string": "a_string"}}, - invocation_state={}, - ) - - tru_result = (await alist(process))[-1] - exp_result = {"toolUseId": "tool_use_id", "status": "success", "content": [{"text": "a_string"}]} - - assert tru_result == exp_result - - -@pytest.mark.asyncio -async def test_run_tool_missing_tool(agent, alist): - process = run_tool( - agent, - tool_use={"toolUseId": "missing", "name": "missing", "input": {}}, - invocation_state={}, - ) - - tru_events = await alist(process) - exp_events = [ - { - "toolUseId": "missing", - "status": "error", - "content": [{"text": "Unknown tool: missing"}], - }, - ] - - assert tru_events == exp_events - - -@pytest.mark.asyncio -async def test_run_tool_hooks(agent, hook_provider, tool_times_2, alist): - """Test that the correct hooks are emitted.""" - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}}, - invocation_state={}, - ) - await alist(process) - - assert len(hook_provider.events_received) == 2 - - assert hook_provider.events_received[0] == BeforeToolInvocationEvent( - agent=agent, - selected_tool=tool_times_2, - tool_use={"input": {"x": 5}, "name": "multiply_by_2", "toolUseId": "test"}, - invocation_state=ANY, - ) - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=tool_times_2, - exception=None, - tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}}, - result={"toolUseId": "test", "status": "success", "content": [{"text": "10"}]}, - invocation_state=ANY, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hooks_on_missing_tool(agent, hook_provider, alist): - """Test that AfterToolInvocation hook is invoked even when tool throws exception.""" - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}}, - invocation_state={}, - ) - await alist(process) - - assert len(hook_provider.events_received) == 2 - - assert hook_provider.events_received[0] == BeforeToolInvocationEvent( - agent=agent, - selected_tool=None, - tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"}, - invocation_state=ANY, - ) - - 
assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=None, - tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"}, - invocation_state=ANY, - result={"content": [{"text": "Unknown tool: missing_tool"}], "status": "error", "toolUseId": "test"}, - exception=None, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hook_after_tool_invocation_on_exception(agent, tool_registry, hook_provider, alist): - """Test that AfterToolInvocation hook is invoked even when tool throws exception.""" - error = ValueError("Tool failed") - - failing_tool = MagicMock() - failing_tool.tool_name = "failing_tool" - - failing_tool.stream.side_effect = error - - tool_registry.register_tool(failing_tool) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "failing_tool", "input": {"x": 5}}, - invocation_state={}, - ) - await alist(process) - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=failing_tool, - tool_use={"input": {"x": 5}, "name": "failing_tool", "toolUseId": "test"}, - invocation_state=ANY, - result={"content": [{"text": "Error: Tool failed"}], "status": "error", "toolUseId": "test"}, - exception=error, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hook_before_tool_invocation_updates(agent, tool_times_5, hook_registry, hook_provider, alist): - """Test that modifying properties on BeforeToolInvocation takes effect.""" - - updated_tool_use = {"toolUseId": "modified", "name": "replacement_tool", "input": {"x": 3}} - - def modify_hook(event: BeforeToolInvocationEvent): - # Modify selected_tool to use replacement_tool - event.selected_tool = tool_times_5 - # Modify tool_use to change toolUseId - event.tool_use = updated_tool_use - - hook_registry.add_callback(BeforeToolInvocationEvent, modify_hook) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "original", "name": "original_tool", "input": {"x": 1}}, - invocation_state={}, - ) - result = (await alist(process))[-1] - - # Should use replacement_tool (5 * 3 = 15) instead of original_tool (1 * 2 = 2) - assert result == {"toolUseId": "modified", "status": "success", "content": [{"text": "15"}]} - - assert hook_provider.events_received[1] == AfterToolInvocationEvent( - agent=agent, - selected_tool=tool_times_5, - tool_use=updated_tool_use, - invocation_state=ANY, - result={"content": [{"text": "15"}], "status": "success", "toolUseId": "modified"}, - exception=None, - ) - - -@pytest.mark.asyncio -async def test_run_tool_hook_after_tool_invocation_updates(agent, tool_times_2, hook_registry, alist): - """Test that modifying properties on AfterToolInvocation takes effect.""" - - updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]} - - def modify_hook(event: AfterToolInvocationEvent): - # Modify result to change the output - event.result = updated_result - - hook_registry.add_callback(AfterToolInvocationEvent, modify_hook) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}}, - invocation_state={}, - ) - - result = (await alist(process))[-1] - assert result == updated_result - - -@pytest.mark.asyncio -async def test_run_tool_hook_after_tool_invocation_updates_with_missing_tool(agent, hook_registry, alist): - """Test that modifying properties on AfterToolInvocation takes effect.""" - - updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": 
"modified_result"}]} - - def modify_hook(event: AfterToolInvocationEvent): - # Modify result to change the output - event.result = updated_result - - hook_registry.add_callback(AfterToolInvocationEvent, modify_hook) - - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}}, - invocation_state={}, - ) - - result = (await alist(process))[-1] - assert result == updated_result - - -@pytest.mark.asyncio -async def test_run_tool_hook_update_result_with_missing_tool(agent, tool_registry, hook_registry, alist): - """Test that modifying properties on AfterToolInvocation takes effect.""" - - @strands.tool - def test_quota(): - return "9" - - tool_registry.register_tool(test_quota) - - class ExampleProvider(HookProvider): - def register_hooks(self, registry: "HookRegistry") -> None: - registry.add_callback(BeforeToolInvocationEvent, self.before_tool_call) - registry.add_callback(AfterToolInvocationEvent, self.after_tool_call) - - def before_tool_call(self, event: BeforeToolInvocationEvent): - if event.tool_use.get("name") == "test_quota": - event.selected_tool = None - - def after_tool_call(self, event: AfterToolInvocationEvent): - if event.tool_use.get("name") == "test_quota": - event.result = { - "status": "error", - "toolUseId": "test", - "content": [{"text": "This tool has been used too many times!"}], - } - - hook_registry.add_hook(ExampleProvider()) - - with patch.object(strands.event_loop.event_loop, "logger") as mock_logger: - process = run_tool( - agent=agent, - tool_use={"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}, - invocation_state={}, - ) - - result = (await alist(process))[-1] - - assert result == { - "status": "error", - "toolUseId": "test", - "content": [{"text": "This tool has been used too many times!"}], - } - - assert mock_logger.debug.call_args_list == [ - call("tool_use=<%s> | streaming", {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}), - call( - "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", - "test_quota", - "test", - ), - ] - - @pytest.mark.asyncio async def test_event_loop_cycle_exception_model_hooks(mock_time, agent, model, agenerator, alist, hook_provider): """Test that model hooks are correctly emitted even when throttled.""" diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py new file mode 100644 index 000000000..1576b7578 --- /dev/null +++ b/tests/strands/tools/executors/conftest.py @@ -0,0 +1,116 @@ +import threading +import unittest.mock + +import pytest + +import strands +from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from strands.hooks import HookRegistry +from strands.tools.registry import ToolRegistry + + +@pytest.fixture +def hook_events(): + return [] + + +@pytest.fixture +def tool_hook(hook_events): + def callback(event): + hook_events.append(event) + return event + + return callback + + +@pytest.fixture +def hook_registry(tool_hook): + registry = HookRegistry() + registry.add_callback(BeforeToolInvocationEvent, tool_hook) + registry.add_callback(AfterToolInvocationEvent, tool_hook) + return registry + + +@pytest.fixture +def tool_events(): + return [] + + +@pytest.fixture +def weather_tool(): + @strands.tool(name="weather_tool") + def func(): + return "sunny" + + return func + + +@pytest.fixture +def temperature_tool(): + @strands.tool(name="temperature_tool") + def func(): + return "75F" + + return func + + +@pytest.fixture +def 
exception_tool(): + @strands.tool(name="exception_tool") + def func(): + pass + + async def mock_stream(_tool_use, _invocation_state): + raise RuntimeError("Tool error") + yield # make generator + + func.stream = mock_stream + return func + + +@pytest.fixture +def thread_tool(tool_events): + @strands.tool(name="thread_tool") + def func(): + tool_events.append({"thread_name": threading.current_thread().name}) + return "threaded" + + return func + + +@pytest.fixture +def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool): + registry = ToolRegistry() + registry.register_tool(weather_tool) + registry.register_tool(temperature_tool) + registry.register_tool(exception_tool) + registry.register_tool(thread_tool) + return registry + + +@pytest.fixture +def agent(tool_registry, hook_registry): + mock_agent = unittest.mock.Mock() + mock_agent.tool_registry = tool_registry + mock_agent.hooks = hook_registry + return mock_agent + + +@pytest.fixture +def tool_results(): + return [] + + +@pytest.fixture +def cycle_trace(): + return unittest.mock.Mock() + + +@pytest.fixture +def cycle_span(): + return unittest.mock.Mock() + + +@pytest.fixture +def invocation_state(): + return {} diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py new file mode 100644 index 000000000..7e0d6c2df --- /dev/null +++ b/tests/strands/tools/executors/test_concurrent.py @@ -0,0 +1,32 @@ +import pytest + +from strands.tools.executors import ConcurrentToolExecutor + + +@pytest.fixture +def executor(): + return ConcurrentToolExecutor() + + +@pytest.mark.asyncio +async def test_concurrent_executor_execute( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + tool_uses = [ + {"name": "weather_tool", "toolUseId": "1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "2", "input": {}}, + ] + stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + + tru_events = sorted(await alist(stream), key=lambda event: event.get("toolUseId")) + exp_events = [ + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + ] + assert tru_events == exp_events + + tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) + exp_results = [exp_events[1], exp_events[3]] + assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py new file mode 100644 index 000000000..edbad3939 --- /dev/null +++ b/tests/strands/tools/executors/test_executor.py @@ -0,0 +1,144 @@ +import unittest.mock + +import pytest + +import strands +from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from strands.telemetry.metrics import Trace +from strands.tools.executors._executor import ToolExecutor + + +@pytest.fixture +def executor_cls(): + class ClsExecutor(ToolExecutor): + def _execute(self, _agent, _tool_uses, _tool_results, _invocation_state): + raise NotImplementedError + + return ClsExecutor + + +@pytest.fixture +def executor(executor_cls): + return executor_cls() + + +@pytest.fixture +def tracer(): + with unittest.mock.patch.object(strands.tools.executors._executor, "get_tracer") as mock_get_tracer: + yield 
mock_get_tracer.return_value + + +@pytest.mark.asyncio +async def test_executor_stream_yields_result( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist +): + tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1]] + assert tru_results == exp_results + + tru_hook_events = hook_events + exp_hook_events = [ + BeforeToolInvocationEvent( + agent=agent, + selected_tool=weather_tool, + tool_use=tool_use, + invocation_state=invocation_state, + ), + AfterToolInvocationEvent( + agent=agent, + selected_tool=weather_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=exp_results[0], + ), + ] + assert tru_hook_events == exp_hook_events + + +@pytest.mark.asyncio +async def test_executor_stream_yields_tool_error( + executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist +): + tool_use = {"name": "exception_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Error: Tool error"}]}] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1]] + assert tru_results == exp_results + + tru_hook_after_event = hook_events[-1] + exp_hook_after_event = AfterToolInvocationEvent( + agent=agent, + selected_tool=exception_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=exp_results[0], + exception=unittest.mock.ANY, + ) + assert tru_hook_after_event == exp_hook_after_event + + +@pytest.mark.asyncio +async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results, invocation_state, hook_events, alist): + tool_use = {"name": "unknown_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Unknown tool: unknown_tool"}]}] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1]] + assert tru_results == exp_results + + tru_hook_after_event = hook_events[-1] + exp_hook_after_event = AfterToolInvocationEvent( + agent=agent, + selected_tool=None, + tool_use=tool_use, + invocation_state=invocation_state, + result=exp_results[0], + ) + assert tru_hook_after_event == exp_hook_after_event + + +@pytest.mark.asyncio +async def test_executor_stream_with_trace( + executor, tracer, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream_with_trace(agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1]] + assert tru_results == exp_results + + 
tracer.start_tool_call_span.assert_called_once_with(tool_use, cycle_span) + tracer.end_tool_call_span.assert_called_once_with( + tracer.start_tool_call_span.return_value, + {"content": [{"text": "sunny"}], "status": "success", "toolUseId": "1"}, + ) + + cycle_trace.add_child.assert_called_once() + assert isinstance(cycle_trace.add_child.call_args[0][0], Trace) diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py new file mode 100644 index 000000000..d9b32c129 --- /dev/null +++ b/tests/strands/tools/executors/test_sequential.py @@ -0,0 +1,32 @@ +import pytest + +from strands.tools.executors import SequentialToolExecutor + + +@pytest.fixture +def executor(): + return SequentialToolExecutor() + + +@pytest.mark.asyncio +async def test_sequential_executor_execute( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + tool_uses = [ + {"name": "weather_tool", "toolUseId": "1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "2", "input": {}}, + ] + stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[1], exp_events[2]] + assert tru_results == exp_results diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py deleted file mode 100644 index 04d4ea657..000000000 --- a/tests/strands/tools/test_executor.py +++ /dev/null @@ -1,440 +0,0 @@ -import unittest.mock -import uuid - -import pytest - -import strands -import strands.telemetry -from strands.types.content import Message - - -@pytest.fixture(autouse=True) -def moto_autouse(moto_env): - _ = moto_env - - -@pytest.fixture -def tool_handler(request): - async def handler(tool_use): - yield {"event": "abc"} - yield { - **params, - "toolUseId": tool_use["toolUseId"], - } - - params = { - "content": [{"text": "test result"}], - "status": "success", - } - if hasattr(request, "param"): - params.update(request.param) - - return handler - - -@pytest.fixture -def tool_use(): - return {"toolUseId": "t1", "name": "test_tool", "input": {"key": "value"}} - - -@pytest.fixture -def tool_uses(request, tool_use): - return request.param if hasattr(request, "param") else [tool_use] - - -@pytest.fixture -def mock_metrics_client(): - with unittest.mock.patch("strands.telemetry.MetricsClient") as mock_metrics_client: - yield mock_metrics_client - - -@pytest.fixture -def event_loop_metrics(): - return strands.telemetry.metrics.EventLoopMetrics() - - -@pytest.fixture -def invalid_tool_use_ids(request): - return request.param if hasattr(request, "param") else [] - - -@pytest.fixture -def cycle_trace(): - with unittest.mock.patch.object(uuid, "uuid4", return_value="trace1"): - return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name") - - -@pytest.mark.asyncio -async def test_run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - tool_results = [] - - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - 
tool_results, - cycle_trace, - ) - - tru_events = await alist(stream) - exp_events = [ - {"event": "abc"}, - { - "content": [ - { - "text": "test result", - }, - ], - "status": "success", - "toolUseId": "t1", - }, - ] - - tru_results = tool_results - exp_results = [exp_events[-1]] - - assert tru_events == exp_events and tru_results == exp_results - - -@pytest.mark.parametrize("invalid_tool_use_ids", [["t1"]], indirect=True) -@pytest.mark.asyncio -async def test_run_tools_invalid_tool( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - tool_results = [] - - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - ) - await alist(stream) - - tru_results = tool_results - exp_results = [] - - assert tru_results == exp_results - - -@pytest.mark.parametrize("tool_handler", [{"status": "failed"}], indirect=True) -@pytest.mark.asyncio -async def test_run_tools_failed_tool( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - tool_results = [] - - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - ) - await alist(stream) - - tru_results = tool_results - exp_results = [ - { - "content": [ - { - "text": "test result", - }, - ], - "status": "failed", - "toolUseId": "t1", - }, - ] - - assert tru_results == exp_results - - -@pytest.mark.parametrize( - ("tool_uses", "invalid_tool_use_ids"), - [ - ( - [ - { - "toolUseId": "t1", - "name": "test_tool_success", - "input": {"key": "value1"}, - }, - { - "toolUseId": "t2", - "name": "test_tool_invalid", - "input": {"key": "value2"}, - }, - ], - ["t2"], - ), - ], - indirect=True, -) -@pytest.mark.asyncio -async def test_run_tools_sequential( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - tool_results = [] - - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - None, # tool_pool - ) - await alist(stream) - - tru_results = tool_results - exp_results = [ - { - "content": [ - { - "text": "test result", - }, - ], - "status": "success", - "toolUseId": "t1", - }, - ] - - assert tru_results == exp_results - - -def test_validate_and_prepare_tools(): - message: Message = { - "role": "assistant", - "content": [ - {"text": "value"}, - {"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {"key": "value"}}}, - {"toolUse": {"toolUseId": "t2-invalid"}}, - ], - } - - tool_uses = [] - tool_results = [] - invalid_tool_use_ids = [] - - strands.tools.executor.validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) - - tru_tool_uses, tru_tool_results, tru_invalid_tool_use_ids = tool_uses, tool_results, invalid_tool_use_ids - exp_tool_uses = [ - { - "input": { - "key": "value", - }, - "name": "test_tool", - "toolUseId": "t1", - }, - { - "name": "INVALID_TOOL_NAME", - "toolUseId": "t2-invalid", - }, - ] - exp_tool_results = [ - { - "content": [ - { - "text": "Error: tool name missing", - }, - ], - "status": "error", - "toolUseId": "t2-invalid", - }, - ] - exp_invalid_tool_use_ids = ["t2-invalid"] - - assert tru_tool_uses == exp_tool_uses - assert tru_tool_results == exp_tool_results - assert tru_invalid_tool_use_ids == exp_invalid_tool_use_ids - - 
-@unittest.mock.patch("strands.tools.executor.get_tracer") -@pytest.mark.asyncio -async def test_run_tools_creates_and_ends_span_on_success( - mock_get_tracer, - tool_handler, - tool_uses, - mock_metrics_client, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - """Test that run_tools creates and ends a span on successful execution.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.return_value = mock_span - mock_get_tracer.return_value = mock_tracer - - # Setup mock parent span - parent_span = unittest.mock.MagicMock() - - tool_results = [] - - # Run the tool - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - parent_span, - ) - await alist(stream) - - # Verify span was created with the parent span - mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) - - # Verify span was ended with the tool result - mock_tracer.end_tool_call_span.assert_called_once() - args, _ = mock_tracer.end_tool_call_span.call_args - assert args[0] == mock_span - assert args[1]["status"] == "success" - assert args[1]["content"][0]["text"] == "test result" - - -@unittest.mock.patch("strands.tools.executor.get_tracer") -@pytest.mark.parametrize("tool_handler", [{"status": "failed"}], indirect=True) -@pytest.mark.asyncio -async def test_run_tools_creates_and_ends_span_on_failure( - mock_get_tracer, - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - """Test that run_tools creates and ends a span on tool failure.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.return_value = mock_span - mock_get_tracer.return_value = mock_tracer - - # Setup mock parent span - parent_span = unittest.mock.MagicMock() - - tool_results = [] - - # Run the tool - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - parent_span, - ) - await alist(stream) - - # Verify span was created with the parent span - mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) - - # Verify span was ended with the tool result - mock_tracer.end_tool_call_span.assert_called_once() - args, _ = mock_tracer.end_tool_call_span.call_args - assert args[0] == mock_span - assert args[1]["status"] == "failed" - - -@unittest.mock.patch("strands.tools.executor.get_tracer") -@pytest.mark.parametrize( - ("tool_uses", "invalid_tool_use_ids"), - [ - ( - [ - { - "toolUseId": "t1", - "name": "test_tool_success", - "input": {"key": "value1"}, - }, - { - "toolUseId": "t2", - "name": "test_tool_also_success", - "input": {"key": "value2"}, - }, - ], - [], - ), - ], - indirect=True, -) -@pytest.mark.asyncio -async def test_run_tools_concurrent_execution_with_spans( - mock_get_tracer, - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - cycle_trace, - alist, -): - # Setup mock tracer and spans - mock_tracer = unittest.mock.MagicMock() - mock_span1 = unittest.mock.MagicMock() - mock_span2 = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.side_effect = [mock_span1, mock_span2] - mock_get_tracer.return_value = mock_tracer - - # Setup mock parent span - parent_span = unittest.mock.MagicMock() - - tool_results = [] - - # 
Run the tools - stream = strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - invalid_tool_use_ids, - tool_results, - cycle_trace, - parent_span, - ) - await alist(stream) - - # Verify spans were created for both tools - assert mock_tracer.start_tool_call_span.call_count == 2 - mock_tracer.start_tool_call_span.assert_has_calls( - [ - unittest.mock.call(tool_uses[0], parent_span), - unittest.mock.call(tool_uses[1], parent_span), - ], - any_order=True, - ) - - # Verify spans were ended for both tools - assert mock_tracer.end_tool_call_span.call_count == 2 diff --git a/tests/strands/tools/test_validator.py b/tests/strands/tools/test_validator.py new file mode 100644 index 000000000..46e5e15f3 --- /dev/null +++ b/tests/strands/tools/test_validator.py @@ -0,0 +1,50 @@ +from strands.tools import _validator +from strands.types.content import Message + + +def test_validate_and_prepare_tools(): + message: Message = { + "role": "assistant", + "content": [ + {"text": "value"}, + {"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {"key": "value"}}}, + {"toolUse": {"toolUseId": "t2-invalid"}}, + ], + } + + tool_uses = [] + tool_results = [] + invalid_tool_use_ids = [] + + _validator.validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) + + tru_tool_uses, tru_tool_results, tru_invalid_tool_use_ids = tool_uses, tool_results, invalid_tool_use_ids + exp_tool_uses = [ + { + "input": { + "key": "value", + }, + "name": "test_tool", + "toolUseId": "t1", + }, + { + "name": "INVALID_TOOL_NAME", + "toolUseId": "t2-invalid", + }, + ] + exp_tool_results = [ + { + "content": [ + { + "text": "Error: tool name missing", + }, + ], + "status": "error", + "toolUseId": "t2-invalid", + }, + ] + exp_invalid_tool_use_ids = ["t2-invalid"] + + assert tru_tool_uses == exp_tool_uses + assert tru_tool_results == exp_tool_results + assert tru_invalid_tool_use_ids == exp_invalid_tool_use_ids diff --git a/tests_integ/tools/executors/test_concurrent.py b/tests_integ/tools/executors/test_concurrent.py new file mode 100644 index 000000000..27dd468e0 --- /dev/null +++ b/tests_integ/tools/executors/test_concurrent.py @@ -0,0 +1,61 @@ +import asyncio + +import pytest + +import strands +from strands import Agent +from strands.tools.executors import ConcurrentToolExecutor + + +@pytest.fixture +def tool_executor(): + return ConcurrentToolExecutor() + + +@pytest.fixture +def tool_events(): + return [] + + +@pytest.fixture +def time_tool(tool_events): + @strands.tool(name="time_tool") + async def func(): + tool_events.append({"name": "time_tool", "event": "start"}) + await asyncio.sleep(2) + tool_events.append({"name": "time_tool", "event": "end"}) + return "12:00" + + return func + + +@pytest.fixture +def weather_tool(tool_events): + @strands.tool(name="weather_tool") + async def func(): + tool_events.append({"name": "weather_tool", "event": "start"}) + await asyncio.sleep(1) + tool_events.append({"name": "weather_tool", "event": "end"}) + + return "sunny" + + return func + + +@pytest.fixture +def agent(tool_executor, time_tool, weather_tool): + return Agent(tools=[time_tool, weather_tool], tool_executor=tool_executor) + + +@pytest.mark.asyncio +async def test_agent_invoke_async_tool_executor(agent, tool_events): + await agent.invoke_async("What is the time and weather in New York?") + + tru_events = tool_events + exp_events = [ + {"name": "time_tool", "event": "start"}, + {"name": "weather_tool", "event": "start"}, + {"name": "weather_tool", "event": "end"}, + {"name": 
"time_tool", "event": "end"}, + ] + assert tru_events == exp_events diff --git a/tests_integ/tools/executors/test_sequential.py b/tests_integ/tools/executors/test_sequential.py new file mode 100644 index 000000000..82fc51a59 --- /dev/null +++ b/tests_integ/tools/executors/test_sequential.py @@ -0,0 +1,61 @@ +import asyncio + +import pytest + +import strands +from strands import Agent +from strands.tools.executors import SequentialToolExecutor + + +@pytest.fixture +def tool_executor(): + return SequentialToolExecutor() + + +@pytest.fixture +def tool_events(): + return [] + + +@pytest.fixture +def time_tool(tool_events): + @strands.tool(name="time_tool") + async def func(): + tool_events.append({"name": "time_tool", "event": "start"}) + await asyncio.sleep(2) + tool_events.append({"name": "time_tool", "event": "end"}) + return "12:00" + + return func + + +@pytest.fixture +def weather_tool(tool_events): + @strands.tool(name="weather_tool") + async def func(): + tool_events.append({"name": "weather_tool", "event": "start"}) + await asyncio.sleep(1) + tool_events.append({"name": "weather_tool", "event": "end"}) + + return "sunny" + + return func + + +@pytest.fixture +def agent(tool_executor, time_tool, weather_tool): + return Agent(tools=[time_tool, weather_tool], tool_executor=tool_executor) + + +@pytest.mark.asyncio +async def test_agent_invoke_async_tool_executor(agent, tool_events): + await agent.invoke_async("What is the time and weather in New York?") + + tru_events = tool_events + exp_events = [ + {"name": "time_tool", "event": "start"}, + {"name": "time_tool", "event": "end"}, + {"name": "weather_tool", "event": "start"}, + {"name": "weather_tool", "event": "end"}, + ] + assert tru_events == exp_events From dbe0fea146749f578bfd73dae22182d69df70a7e Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 25 Aug 2025 11:06:43 -0400 Subject: [PATCH 050/104] feat: Add support for agent invoke with no input, or Message input (#653) --- src/strands/agent/agent.py | 122 ++++++++++++++++++------- src/strands/telemetry/tracer.py | 17 ++-- tests/strands/agent/test_agent.py | 82 ++++++++++++++--- tests/strands/telemetry/test_tracer.py | 2 +- 4 files changed, 168 insertions(+), 55 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index adc554bf4..654b8edce 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -361,14 +361,21 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult: + def __call__(self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any) -> AgentResult: """Process a natural language prompt through the agent's event loop. - This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to - the conversation history, processes it through the model, executes any tool calls, and returns the final result. + This method implements the conversational interface with multiple input patterns: + - String input: `agent("hello!")` + - ContentBlock list: `agent([{"text": "hello"}, {"image": {...}}])` + - Message list: `agent([{"role": "user", "content": [{"text": "hello"}]}])` + - No input: `agent()` - uses existing conversation history Args: - prompt: User input as text or list of ContentBlock objects for multi-modal content. 
+ prompt: User input in various formats: + - str: Simple text input + - list[ContentBlock]: Multi-modal content blocks + - list[Message]: Complete messages with roles + - None: Use existing conversation history **kwargs: Additional parameters to pass through the event loop. Returns: @@ -387,14 +394,23 @@ def execute() -> AgentResult: future = executor.submit(execute) return future.result() - async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult: + async def invoke_async( + self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any + ) -> AgentResult: """Process a natural language prompt through the agent's event loop. - This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to - the conversation history, processes it through the model, executes any tool calls, and returns the final result. + This method implements the conversational interface with multiple input patterns: + - String input: Simple text input + - ContentBlock list: Multi-modal content blocks + - Message list: Complete messages with roles + - No input: Use existing conversation history Args: - prompt: User input as text or list of ContentBlock objects for multi-modal content. + prompt: User input in various formats: + - str: Simple text input + - list[ContentBlock]: Multi-modal content blocks + - list[Message]: Complete messages with roles + - None: Use existing conversation history **kwargs: Additional parameters to pass through the event loop. Returns: @@ -411,7 +427,7 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A return cast(AgentResult, event["result"]) - def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T: + def structured_output(self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -423,7 +439,11 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l Args: output_model: The output model (a JSON schema written as a Pydantic BaseModel) that the agent will use when responding. - prompt: The prompt to use for the agent (will not be added to conversation history). + prompt: The prompt to use for the agent in various formats: + - str: Simple text input + - list[ContentBlock]: Multi-modal content blocks + - list[Message]: Complete messages with roles + - None: Use existing conversation history Raises: ValueError: If no conversation history or prompt is provided. @@ -437,7 +457,7 @@ def execute() -> T: return future.result() async def structured_output_async( - self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None + self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None ) -> T: """This method allows you to get structured output from the agent. 
@@ -462,12 +482,8 @@ async def structured_output_async( try: if not self.messages and not prompt: raise ValueError("No conversation history or prompt provided") - # Create temporary messages array if prompt is provided - if prompt: - content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt - temp_messages = self.messages + [{"role": "user", "content": content}] - else: - temp_messages = self.messages + + temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt) structured_output_span.set_attributes( { @@ -499,16 +515,25 @@ async def structured_output_async( finally: self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) - async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]: + async def stream_async( + self, + prompt: str | list[ContentBlock] | Messages | None = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. - This method provides an asynchronous interface for streaming agent events, allowing - consumers to process stream events programmatically through an async iterator pattern - rather than callback functions. This is particularly useful for web servers and other - async environments. + This method provides an asynchronous interface for streaming agent events with multiple input patterns: + - String input: Simple text input + - ContentBlock list: Multi-modal content blocks + - Message list: Complete messages with roles + - No input: Use existing conversation history Args: - prompt: User input as text or list of ContentBlock objects for multi-modal content. + prompt: User input in various formats: + - str: Simple text input + - list[ContentBlock]: Multi-modal content blocks + - list[Message]: Complete messages with roles + - None: Use existing conversation history **kwargs: Additional parameters to pass to the event loop. Yields: @@ -532,13 +557,15 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A """ callback_handler = kwargs.get("callback_handler", self.callback_handler) - content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt - message: Message = {"role": "user", "content": content} + # Process input and get message to add (if any) + messages = self._convert_prompt_to_messages(prompt) + + self.trace_span = self._start_agent_trace_span(messages) - self.trace_span = self._start_agent_trace_span(message) with trace_api.use_span(self.trace_span): try: - events = self._run_loop(message, invocation_state=kwargs) + events = self._run_loop(messages, invocation_state=kwargs) + async for event in events: if "callback" in event: callback_handler(**event["callback"]) @@ -555,12 +582,12 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A raise async def _run_loop( - self, message: Message, invocation_state: dict[str, Any] + self, messages: Messages, invocation_state: dict[str, Any] ) -> AsyncGenerator[dict[str, Any], None]: """Execute the agent's event loop with the given message and parameters. Args: - message: The user message to add to the conversation. + messages: The input messages to add to the conversation. invocation_state: Additional parameters to pass to the event loop. 
Yields: @@ -571,7 +598,8 @@ async def _run_loop( try: yield {"callback": {"init_event_loop": True, **invocation_state}} - self._append_message(message) + for message in messages: + self._append_message(message) # Execute the event loop cycle with retry logic for context limits events = self._execute_event_loop_cycle(invocation_state) @@ -629,6 +657,34 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A async for event in events: yield event + def _convert_prompt_to_messages(self, prompt: str | list[ContentBlock] | Messages | None) -> Messages: + messages: Messages | None = None + if prompt is not None: + if isinstance(prompt, str): + # String input - convert to user message + messages = [{"role": "user", "content": [{"text": prompt}]}] + elif isinstance(prompt, list): + if len(prompt) == 0: + # Empty list + messages = [] + # Check if all item in input list are dictionaries + elif all(isinstance(item, dict) for item in prompt): + # Check if all items are messages + if all(all(key in item for key in Message.__annotations__.keys()) for item in prompt): + # Messages input - add all messages to conversation + messages = cast(Messages, prompt) + + # Check if all items are content blocks + elif all(any(key in ContentBlock.__annotations__.keys() for key in item) for item in prompt): + # Treat as List[ContentBlock] input - convert to user message + # This allows invalid structures to be passed through to the model + messages = [{"role": "user", "content": cast(list[ContentBlock], prompt)}] + else: + messages = [] + if messages is None: + raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.") + return messages + def _record_tool_execution( self, tool: ToolUse, @@ -694,15 +750,15 @@ def _record_tool_execution( self._append_message(tool_result_msg) self._append_message(assistant_msg) - def _start_agent_trace_span(self, message: Message) -> trace_api.Span: + def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: """Starts a trace span for the agent. Args: - message: The user message. + messages: The input messages. """ model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None return self.tracer.start_agent_span( - message=message, + messages=messages, agent_name=self.name, model_id=model_id, tools=self.tool_names, diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 802865189..6b429393d 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -408,7 +408,7 @@ def end_event_loop_cycle_span( def start_agent_span( self, - message: Message, + messages: Messages, agent_name: str, model_id: Optional[str] = None, tools: Optional[list] = None, @@ -418,7 +418,7 @@ def start_agent_span( """Start a new span for an agent invocation. Args: - message: The user message being sent to the agent. + messages: List of messages being sent to the agent. agent_name: Name of the agent. model_id: Optional model identifier. tools: Optional list of tools being used. 
@@ -451,13 +451,12 @@ def start_agent_span( span = self._start_span( f"invoke_agent {agent_name}", attributes=attributes, span_kind=trace_api.SpanKind.CLIENT ) - self._add_event( - span, - "gen_ai.user.message", - event_attributes={ - "content": serialize(message["content"]), - }, - ) + for message in messages: + self._add_event( + span, + f"gen_ai.{message['role']}.message", + {"content": serialize(message["content"])}, + ) return span diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 279e2a06e..01d8f977e 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1332,12 +1332,12 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( + messages=[{"content": [{"text": "test prompt"}], "role": "user"}], agent_name="Strands Agents", - custom_trace_attributes=agent.trace_attributes, - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) # Verify span was ended with the result @@ -1366,12 +1366,12 @@ async def test_event_loop(*args, **kwargs): # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( - custom_trace_attributes=agent.trace_attributes, + messages=[{"content": [{"text": "test prompt"}], "role": "user"}], agent_name="Strands Agents", - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) expected_response = AgentResult( @@ -1404,12 +1404,12 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( - custom_trace_attributes=agent.trace_attributes, + messages=[{"content": [{"text": "test prompt"}], "role": "user"}], agent_name="Strands Agents", - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) # Verify span was ended with the exception @@ -1440,12 +1440,12 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( + messages=[{"content": [{"text": "test prompt"}], "role": "user"}], agent_name="Strands Agents", - custom_trace_attributes=agent.trace_attributes, - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) # Verify span was ended with the exception @@ -1773,6 +1773,63 @@ def test_agent_tool_record_direct_tool_call_disabled_with_non_serializable(agent assert len(agent.messages) == 0 +def test_agent_empty_invoke(): + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) + agent = Agent(model=model, messages=[{"role": "user", "content": [{"text": "hello!"}]}]) + result = agent() + assert str(result) == "hello!\n" + assert len(agent.messages) == 2 + + +def test_agent_empty_list_invoke(): + model = 
MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) + agent = Agent(model=model, messages=[{"role": "user", "content": [{"text": "hello!"}]}]) + result = agent([]) + assert str(result) == "hello!\n" + assert len(agent.messages) == 2 + + +def test_agent_with_assistant_role_message(): + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}]) + agent = Agent(model=model) + assistant_message = [{"role": "assistant", "content": [{"text": "hello..."}]}] + result = agent(assistant_message) + assert str(result) == "world!\n" + assert len(agent.messages) == 2 + + +def test_agent_with_multiple_messages_on_invoke(): + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}]) + agent = Agent(model=model) + input_messages = [ + {"role": "user", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "..."}]}, + ] + result = agent(input_messages) + assert str(result) == "world!\n" + assert len(agent.messages) == 3 + + +def test_agent_with_invalid_input(): + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}]) + agent = Agent(model=model) + with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[Contentblock] | Messages | None`."): + agent({"invalid": "input"}) + + +def test_agent_with_invalid_input_list(): + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}]) + agent = Agent(model=model) + with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[Contentblock] | Messages | None`."): + agent([{"invalid": "input"}]) + + +def test_agent_with_list_of_message_and_content_block(): + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}]) + agent = Agent(model=model) + with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[Contentblock] | Messages | None`."): + agent([{"role": "user", "content": [{"text": "hello"}]}, {"text", "hello"}]) + def test_agent_tool_call_parameter_filtering_integration(mock_randint): """Test that tool calls properly filter parameters in message recording.""" mock_randint.return_value = 42 @@ -1804,3 +1861,4 @@ def test_tool(action: str) -> str: assert '"action": "test_value"' in tool_call_text assert '"agent"' not in tool_call_text assert '"extra_param"' not in tool_call_text + diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index dcfce1211..586911bef 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -369,7 +369,7 @@ def test_start_agent_span(mock_tracer): span = tracer.start_agent_span( custom_trace_attributes=custom_attrs, agent_name="WeatherAgent", - message={"content": content, "role": "user"}, + messages=[{"content": content, "role": "user"}], model_id=model_id, tools=tools, ) From b156ea68c824fdb968d4d986a835878b0bfc1b93 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:06:54 -0400 Subject: [PATCH 051/104] ci: bump actions/checkout from 4 to 5 (#711) Bumps [actions/checkout](https://github.com/actions/checkout) from 4 to 5. 
- [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: '5' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration-test.yml | 2 +- .github/workflows/pypi-publish-on-release.yml | 2 +- .github/workflows/test-lint.yml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index c347e3805..d410bb712 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -52,7 +52,7 @@ jobs: aws-region: us-east-1 mask-aws-account-id: true - name: Checkout head commit - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo persist-credentials: false # Don't persist credentials for subsequent actions diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 8967c5524..e3c5385a7 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: persist-credentials: false diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 35e0f5841..c0ed4faca 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -51,7 +51,7 @@ jobs: LOG_LEVEL: DEBUG steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: ref: ${{ inputs.ref }} # Explicitly define which commit to check out persist-credentials: false # Don't persist credentials for subsequent actions @@ -73,7 +73,7 @@ jobs: contents: read steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: ref: ${{ inputs.ref }} persist-credentials: false From 0283169c7a3e424494e6260d163324f75eeeb8f7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:07:08 -0400 Subject: [PATCH 052/104] ci: bump actions/download-artifact from 4 to 5 (#712) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 4 to 5. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-version: '5' dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/pypi-publish-on-release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index e3c5385a7..c2420d747 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -74,7 +74,7 @@ jobs: steps: - name: Download all the dists - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: python-package-distributions path: dist/ From e5e308ff794d02eca035a96e148478ed14747ea9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:25:31 -0400 Subject: [PATCH 053/104] ci: update pytest-cov requirement from <5.0.0,>=4.1.0 to >=4.1.0,<7.0.0 (#705) --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 32de94aa6..8a95ba04c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,8 +57,8 @@ dev = [ "mypy>=1.15.0,<2.0.0", "pre-commit>=3.2.0,<4.4.0", "pytest>=8.0.0,<9.0.0", + "pytest-cov>=6.0.0,<7.0.0", "pytest-asyncio>=1.0.0,<1.2.0", - "pytest-cov>=4.1.0,<5.0.0", "pytest-xdist>=3.0.0,<4.0.0", "ruff>=0.12.0,<0.13.0", ] @@ -143,8 +143,8 @@ features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mis extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", + "pytest-cov>=6.0.0,<7.0.0", "pytest-asyncio>=1.0.0,<1.2.0", - "pytest-cov>=4.1.0,<5.0.0", "pytest-xdist>=3.0.0,<4.0.0", ] extra-args = [ From 918f0945ea9dd0c786bba9af814268d1387a818b Mon Sep 17 00:00:00 2001 From: mehtarac Date: Mon, 25 Aug 2025 08:34:11 -0700 Subject: [PATCH 054/104] fix: prevent path traversal for message_id in file_session_manager (#728) * fix: prevent path traversal for message_id in file_session_manager * fix: prevent path traversal for message_id in session managers * fix: prevent path traversal for message_id in session managers --- src/strands/session/file_session_manager.py | 6 +++++ src/strands/session/s3_session_manager.py | 7 +++++- .../session/test_file_session_manager.py | 22 +++++++++++++++++-- .../session/test_s3_session_manager.py | 20 ++++++++++++++++- 4 files changed, 51 insertions(+), 4 deletions(-) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 9df86e17a..14e71d07c 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -86,7 +86,13 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> message_id: Index of the message Returns: The filename for the message + + Raises: + ValueError: If message_id is not an integer. 
""" + if not isinstance(message_id, int): + raise ValueError(f"message_id=<{message_id}> | message id must be an integer") + agent_path = self._get_agent_path(session_id, agent_id) return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json") diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index d15e6e3bd..da1735e35 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -113,11 +113,16 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> session_id: ID of the session agent_id: ID of the agent message_id: Index of the message - **kwargs: Additional keyword arguments for future extensibility. Returns: The key for the message + + Raises: + ValueError: If message_id is not an integer. """ + if not isinstance(message_id, int): + raise ValueError(f"message_id=<{message_id}> | message id must be an integer") + agent_path = self._get_agent_path(session_id, agent_id) return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index a89222b7e..036591924 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -224,14 +224,14 @@ def test_read_messages_with_new_agent(file_manager, sample_session, sample_agent file_manager.create_session(sample_session) file_manager.create_agent(sample_session.session_id, sample_agent) - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) assert result is None def test_read_nonexistent_message(file_manager, sample_session, sample_agent): """Test reading a message that doesnt exist.""" - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999) assert result is None @@ -390,3 +390,21 @@ def test__get_session_path_invalid_session_id(session_id, file_manager): def test__get_agent_path_invalid_agent_id(agent_id, file_manager): with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): file_manager._get_agent_path("session1", agent_id) + + +@pytest.mark.parametrize( + "message_id", + [ + "../../../secret", + "../../attack", + "../escape", + "path/traversal", + "not_an_int", + None, + [], + ], +) +def test__get_message_path_invalid_message_id(message_id, file_manager): + """Test that message_id that is not an integer raises ValueError.""" + with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): + file_manager._get_message_path("session1", "agent1", message_id) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 71bff3050..50fb303f7 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -251,7 +251,7 @@ def test_read_nonexistent_message(s3_manager, sample_session, sample_agent, samp s3_manager.create_agent(sample_session.session_id, sample_agent) # Read message - result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, 
999) assert result is None @@ -356,3 +356,21 @@ def test__get_session_path_invalid_session_id(session_id, s3_manager): def test__get_agent_path_invalid_agent_id(agent_id, s3_manager): with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): s3_manager._get_agent_path("session1", agent_id) + + +@pytest.mark.parametrize( + "message_id", + [ + "../../../secret", + "../../attack", + "../escape", + "path/traversal", + "not_an_int", + None, + [], + ], +) +def test__get_message_path_invalid_message_id(message_id, s3_manager): + """Test that message_id that is not an integer raises ValueError.""" + with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): + s3_manager._get_message_path("session1", "agent1", message_id) From f028dc96df64d97f1f5a05be9ec2fc7cd8467a8d Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 25 Aug 2025 13:31:42 -0400 Subject: [PATCH 055/104] fix: Add AgentInput TypeAlias (#738) --- CONTRIBUTING.md | 2 +- src/strands/agent/agent.py | 32 ++++++++++++------- src/strands/session/file_session_manager.py | 2 +- src/strands/session/s3_session_manager.py | 4 +-- tests/strands/agent/test_agent.py | 2 +- .../session/test_file_session_manager.py | 2 +- .../session/test_s3_session_manager.py | 2 +- 7 files changed, 28 insertions(+), 18 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index add4825fd..93970ed64 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -49,7 +49,7 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as Alternatively, install development dependencies in a manually created virtual environment: ```bash - pip install -e ".[dev]" && pip install -e ".[litellm]" + pip install -e ".[all]" ``` diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 654b8edce..66099cb1d 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -14,7 +14,19 @@ import logging import random from concurrent.futures import ThreadPoolExecutor -from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Callable, + Mapping, + Optional, + Type, + TypeAlias, + TypeVar, + Union, + cast, +) from opentelemetry import trace as trace_api from pydantic import BaseModel @@ -55,6 +67,8 @@ # TypeVar for generic structured output T = TypeVar("T", bound=BaseModel) +AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None + # Sentinel class and object to distinguish between explicit None and default parameter value class _DefaultCallbackHandlerSentinel: @@ -361,7 +375,7 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - def __call__(self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any) -> AgentResult: + def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface with multiple input patterns: @@ -394,9 +408,7 @@ def execute() -> AgentResult: future = executor.submit(execute) return future.result() - async def invoke_async( - self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any - ) -> AgentResult: + async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: """Process a natural language prompt through the agent's event loop. 
This method implements the conversational interface with multiple input patterns: @@ -427,7 +439,7 @@ async def invoke_async( return cast(AgentResult, event["result"]) - def structured_output(self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None) -> T: + def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -456,9 +468,7 @@ def execute() -> T: future = executor.submit(execute) return future.result() - async def structured_output_async( - self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None - ) -> T: + async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -517,7 +527,7 @@ async def structured_output_async( async def stream_async( self, - prompt: str | list[ContentBlock] | Messages | None = None, + prompt: AgentInput = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -657,7 +667,7 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A async for event in events: yield event - def _convert_prompt_to_messages(self, prompt: str | list[ContentBlock] | Messages | None) -> Messages: + def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: messages: Messages | None = None if prompt is not None: if isinstance(prompt, str): diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 14e71d07c..491f7ad60 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -92,7 +92,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> """ if not isinstance(message_id, int): raise ValueError(f"message_id=<{message_id}> | message id must be an integer") - + agent_path = self._get_agent_path(session_id, agent_id) return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json") diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index da1735e35..c6ce28d80 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -116,13 +116,13 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> Returns: The key for the message - + Raises: ValueError: If message_id is not an integer. 
""" if not isinstance(message_id, int): raise ValueError(f"message_id=<{message_id}> | message id must be an integer") - + agent_path = self._get_agent_path(session_id, agent_id) return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 01d8f977e..67ea5940a 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1830,6 +1830,7 @@ def test_agent_with_list_of_message_and_content_block(): with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[Contentblock] | Messages | None`."): agent([{"role": "user", "content": [{"text": "hello"}]}, {"text", "hello"}]) + def test_agent_tool_call_parameter_filtering_integration(mock_randint): """Test that tool calls properly filter parameters in message recording.""" mock_randint.return_value = 42 @@ -1861,4 +1862,3 @@ def test_tool(action: str) -> str: assert '"action": "test_value"' in tool_call_text assert '"agent"' not in tool_call_text assert '"extra_param"' not in tool_call_text - diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index 036591924..f124ddf58 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -396,7 +396,7 @@ def test__get_agent_path_invalid_agent_id(agent_id, file_manager): "message_id", [ "../../../secret", - "../../attack", + "../../attack", "../escape", "path/traversal", "not_an_int", diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 50fb303f7..c4d6a0154 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -362,7 +362,7 @@ def test__get_agent_path_invalid_agent_id(agent_id, s3_manager): "message_id", [ "../../../secret", - "../../attack", + "../../attack", "../escape", "path/traversal", "not_an_int", From 0fac6480b5d64bf4500c0ea257e0d237a639cd64 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 26 Aug 2025 10:39:56 -0400 Subject: [PATCH 056/104] fix: Move AgentInput to types submodule (#746) --- src/strands/agent/agent.py | 4 +--- src/strands/types/agent.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) create mode 100644 src/strands/types/agent.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 66099cb1d..e2aed9d2b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -22,7 +22,6 @@ Mapping, Optional, Type, - TypeAlias, TypeVar, Union, cast, @@ -51,6 +50,7 @@ from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher +from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException from ..types.tools import ToolResult, ToolUse @@ -67,8 +67,6 @@ # TypeVar for generic structured output T = TypeVar("T", bound=BaseModel) -AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None - # Sentinel class and object to distinguish between explicit None and default parameter value class _DefaultCallbackHandlerSentinel: diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py new file mode 100644 index 000000000..151c88f89 --- /dev/null +++ b/src/strands/types/agent.py @@ -0,0 +1,10 @@ +"""Agent-related type definitions for the SDK. 
+ +This module defines the types used for an Agent. +""" + +from typing import TypeAlias + +from .content import ContentBlock, Messages + +AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None From aa03b3dfffbc98303bba8f57a19e98b1bdb239af Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 27 Aug 2025 09:14:45 -0400 Subject: [PATCH 057/104] feat: Implement typed events internally (#745) Step 1/N for implementing typed-events; first just preserve the existing behaviors with no changes to the public api. A follow-up change will update how we invoke callbacks and pass invocation_state around, while this one just adds typed classes for events internally. --------- Co-authored-by: Mackenzie Zastrow --- src/strands/agent/agent.py | 3 +- src/strands/event_loop/event_loop.py | 35 ++- src/strands/event_loop/streaming.py | 50 ++-- src/strands/types/_events.py | 238 ++++++++++++++++++ .../strands/agent/hooks/test_agent_events.py | 159 ++++++++++++ tests/strands/agent/test_agent.py | 129 +++++----- tests/strands/event_loop/test_streaming.py | 9 + 7 files changed, 529 insertions(+), 94 deletions(-) create mode 100644 src/strands/types/_events.py create mode 100644 tests/strands/agent/hooks/test_agent_events.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e2aed9d2b..8233c4bfe 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -50,6 +50,7 @@ from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher +from ..types._events import InitEventLoopEvent from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException @@ -604,7 +605,7 @@ async def _run_loop( self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) try: - yield {"callback": {"init_event_loop": True, **invocation_state}} + yield InitEventLoopEvent(invocation_state) for message in messages: self._append_message(message) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 524ecc3e8..a166902eb 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -25,6 +25,15 @@ from ..telemetry.metrics import Trace from ..telemetry.tracer import get_tracer from ..tools._validator import validate_and_prepare_tools +from ..types._events import ( + EventLoopStopEvent, + EventLoopThrottleEvent, + ForceStopEvent, + ModelMessageEvent, + StartEvent, + StartEventLoopEvent, + ToolResultMessageEvent, +) from ..types.content import Message from ..types.exceptions import ( ContextWindowOverflowException, @@ -91,8 +100,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes) invocation_state["event_loop_cycle_trace"] = cycle_trace - yield {"callback": {"start": True}} - yield {"callback": {"start_event_loop": True}} + yield StartEvent() + yield StartEventLoopEvent() # Create tracer span for this event loop cycle tracer = get_tracer() @@ -175,7 +184,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> if isinstance(e, ModelThrottledException): if attempt + 1 == MAX_ATTEMPTS: - yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}} + yield ForceStopEvent(reason=e) raise e logger.debug( @@ -189,7 +198,7 @@ async def event_loop_cycle(agent: 
"Agent", invocation_state: dict[str, Any]) -> time.sleep(current_delay) current_delay = min(current_delay * 2, MAX_DELAY) - yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}} + yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state) else: raise e @@ -201,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Add the response message to the conversation agent.messages.append(message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield {"callback": {"message": message}} + yield ModelMessageEvent(message=message) # Update metrics agent.event_loop_metrics.update_usage(usage) @@ -235,8 +244,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_start_time=cycle_start_time, invocation_state=invocation_state, ) - async for event in events: - yield event + async for typed_event in events: + yield typed_event return @@ -264,11 +273,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> tracer.end_span_with_error(cycle_span, str(e), e) # Handle any other exceptions - yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}} + yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e - yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: @@ -295,7 +304,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id) cycle_trace.add_child(recursive_trace) - yield {"callback": {"start": True}} + yield StartEvent() events = event_loop_cycle(agent=agent, invocation_state=invocation_state) async for event in events: @@ -339,7 +348,7 @@ async def _handle_tool_execution( validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] if not tool_uses: - yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) return tool_events = agent.tool_executor._execute( @@ -358,7 +367,7 @@ async def _handle_tool_execution( agent.messages.append(tool_result_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) - yield {"callback": {"message": tool_result_message}} + yield ToolResultMessageEvent(message=message) if cycle_span: tracer = get_tracer() @@ -366,7 +375,7 @@ async def _handle_tool_execution( if invocation_state["request_state"].get("stop_event_loop", False): agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) - yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) return events = recurse_event_loop(agent=agent, invocation_state=invocation_state) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 1f8c260a4..7507c6d75 100644 --- 
a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -5,6 +5,16 @@ from typing import Any, AsyncGenerator, AsyncIterable, Optional from ..models.model import Model +from ..types._events import ( + ModelStopReason, + ModelStreamChunkEvent, + ModelStreamEvent, + ReasoningSignatureStreamEvent, + ReasoningTextStreamEvent, + TextStreamEvent, + ToolUseStreamEvent, + TypedEvent, +) from ..types.content import ContentBlock, Message, Messages from ..types.streaming import ( ContentBlockDeltaEvent, @@ -105,7 +115,7 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: def handle_content_block_delta( event: ContentBlockDeltaEvent, state: dict[str, Any] -) -> tuple[dict[str, Any], dict[str, Any]]: +) -> tuple[dict[str, Any], ModelStreamEvent]: """Handles content block delta updates by appending text, tool input, or reasoning content to the state. Args: @@ -117,18 +127,18 @@ def handle_content_block_delta( """ delta_content = event["delta"] - callback_event = {} + typed_event: ModelStreamEvent = ModelStreamEvent({}) if "toolUse" in delta_content: if "input" not in state["current_tool_use"]: state["current_tool_use"]["input"] = "" state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] - callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} + typed_event = ToolUseStreamEvent(delta_content, state["current_tool_use"]) elif "text" in delta_content: state["text"] += delta_content["text"] - callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} + typed_event = TextStreamEvent(text=delta_content["text"], delta=delta_content) elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -136,24 +146,22 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - callback_event["callback"] = { - "reasoningText": delta_content["reasoningContent"]["text"], - "delta": delta_content, - "reasoning": True, - } + typed_event = ReasoningTextStreamEvent( + reasoning_text=delta_content["reasoningContent"]["text"], + delta=delta_content, + ) elif "signature" in delta_content["reasoningContent"]: if "signature" not in state: state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_event["callback"] = { - "reasoning_signature": delta_content["reasoningContent"]["signature"], - "delta": delta_content, - "reasoning": True, - } + typed_event = ReasoningSignatureStreamEvent( + reasoning_signature=delta_content["reasoningContent"]["signature"], + delta=delta_content, + ) - return state, callback_event + return state, typed_event def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: @@ -251,7 +259,7 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: return usage, metrics -async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[dict[str, Any], None]: +async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[TypedEvent, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. 
Args: @@ -274,14 +282,14 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d metrics: Metrics = Metrics(latencyMs=0) async for chunk in chunks: - yield {"callback": {"event": chunk}} + yield ModelStreamChunkEvent(chunk=chunk) if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) elif "contentBlockStart" in chunk: state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"]) elif "contentBlockDelta" in chunk: - state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state) - yield callback_event + state, typed_event = handle_content_block_delta(chunk["contentBlockDelta"], state) + yield typed_event elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: @@ -291,7 +299,7 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d elif "redactContent" in chunk: handle_redact_content(chunk["redactContent"], state) - yield {"stop": (stop_reason, state["message"], usage, metrics)} + yield ModelStopReason(stop_reason=stop_reason, message=state["message"], usage=usage, metrics=metrics) async def stream_messages( @@ -299,7 +307,7 @@ async def stream_messages( system_prompt: Optional[str], messages: Messages, tool_specs: list[ToolSpec], -) -> AsyncGenerator[dict[str, Any], None]: +) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. Args: diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py new file mode 100644 index 000000000..1bddc5877 --- /dev/null +++ b/src/strands/types/_events.py @@ -0,0 +1,238 @@ +"""event system for the Strands Agents framework. + +This module defines the event types that are emitted during agent execution, +providing a structured way to observe to different events of the event loop and +agent lifecycle. +""" + +from typing import TYPE_CHECKING, Any + +from ..telemetry import EventLoopMetrics +from .content import Message +from .event_loop import Metrics, StopReason, Usage +from .streaming import ContentBlockDelta, StreamEvent + +if TYPE_CHECKING: + pass + + +class TypedEvent(dict): + """Base class for all typed events in the agent system.""" + + def __init__(self, data: dict[str, Any] | None = None) -> None: + """Initialize the typed event with optional data. + + Args: + data: Optional dictionary of event data to initialize with + """ + super().__init__(data or {}) + + +class InitEventLoopEvent(TypedEvent): + """Event emitted at the very beginning of agent execution. + + This event is fired before any processing begins and provides access to the + initial invocation state. + + Args: + invocation_state: The invocation state passed into the request + """ + + def __init__(self, invocation_state: dict) -> None: + """Initialize the event loop initialization event.""" + super().__init__({"callback": {"init_event_loop": True, **invocation_state}}) + + +class StartEvent(TypedEvent): + """Event emitted at the start of each event loop cycle. + + !!deprecated!! + Use StartEventLoopEvent instead. + + This event events the beginning of a new processing cycle within the agent's + event loop. It's fired before model invocation and tool execution begin. + """ + + def __init__(self) -> None: + """Initialize the event loop start event.""" + super().__init__({"callback": {"start": True}}) + + +class StartEventLoopEvent(TypedEvent): + """Event emitted when the event loop cycle begins processing. 
+
+    This event is fired after StartEvent and indicates that the event loop
+    has begun its core processing logic, including model invocation preparation.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the event loop processing start event."""
+        super().__init__({"callback": {"start_event_loop": True}})
+
+
+class ModelStreamChunkEvent(TypedEvent):
+    """Event emitted during model response streaming for each raw chunk."""
+
+    def __init__(self, chunk: StreamEvent) -> None:
+        """Initialize with streaming delta data from the model.
+
+        Args:
+            chunk: Incremental streaming data from the model response
+        """
+        super().__init__({"callback": {"event": chunk}})
+
+
+class ModelStreamEvent(TypedEvent):
+    """Event emitted during model response streaming.
+
+    This event is fired when the model produces streaming output during response
+    generation.
+    """
+
+    def __init__(self, delta_data: dict[str, Any]) -> None:
+        """Initialize with streaming delta data from the model.
+
+        Args:
+            delta_data: Incremental streaming data from the model response
+        """
+        super().__init__(delta_data)
+
+
+class ToolUseStreamEvent(ModelStreamEvent):
+    """Event emitted during tool use input streaming."""
+
+    def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None:
+        """Initialize with delta and current tool use state."""
+        super().__init__({"callback": {"delta": delta, "current_tool_use": current_tool_use}})
+
+
+class TextStreamEvent(ModelStreamEvent):
+    """Event emitted during text content streaming."""
+
+    def __init__(self, delta: ContentBlockDelta, text: str) -> None:
+        """Initialize with delta and text content."""
+        super().__init__({"callback": {"data": text, "delta": delta}})
+
+
+class ReasoningTextStreamEvent(ModelStreamEvent):
+    """Event emitted during reasoning text streaming."""
+
+    def __init__(self, delta: ContentBlockDelta, reasoning_text: str | None) -> None:
+        """Initialize with delta and reasoning text."""
+        super().__init__({"callback": {"reasoningText": reasoning_text, "delta": delta, "reasoning": True}})
+
+
+class ReasoningSignatureStreamEvent(ModelStreamEvent):
+    """Event emitted during reasoning signature streaming."""
+
+    def __init__(self, delta: ContentBlockDelta, reasoning_signature: str | None) -> None:
+        """Initialize with delta and reasoning signature."""
+        super().__init__({"callback": {"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True}})
+
+
+class ModelStopReason(TypedEvent):
+    """Event emitted when the model stream completes and reports its stop reason."""
+
+    def __init__(
+        self,
+        stop_reason: StopReason,
+        message: Message,
+        usage: Usage,
+        metrics: Metrics,
+    ) -> None:
+        """Initialize with the final execution results.
+
+        Args:
+            stop_reason: Why the agent execution stopped
+            message: Final message from the model
+            metrics: Execution metrics and performance data
+            request_state: Final state of the agent execution
+        """
+        super().__init__({"stop": (stop_reason, message, metrics, request_state)})
+
+
+class EventLoopThrottleEvent(TypedEvent):
+    """Event emitted when the event loop is throttled due to rate limiting."""
+
+    def __init__(self, delay: int, invocation_state: dict[str, Any]) -> None:
+        """Initialize with the throttle delay duration.
+
+        Args:
+            delay: Delay in seconds before the next retry attempt
+            invocation_state: The invocation state passed into the request
+        """
+        super().__init__({"callback": {"event_loop_throttled_delay": delay, **invocation_state}})
+
+
+class ModelMessageEvent(TypedEvent):
+    """Event emitted when the model invocation has completed.
+
+    This event is fired whenever the model generates a response message that
+    gets added to the conversation history.
+    """
+
+    def __init__(self, message: Message) -> None:
+        """Initialize with the model-generated message.
+
+        Args:
+            message: The response message from the model
+        """
+        super().__init__({"callback": {"message": message}})
+
+
+class ToolResultMessageEvent(TypedEvent):
+    """Event emitted when tool results are formatted as a message.
+
+    This event is fired when tool execution results are converted into a
+    message format to be added to the conversation history. It provides
+    access to the formatted message containing tool results.
+    """
+
+    def __init__(self, message: Any) -> None:
+        """Initialize with the tool result message.
+
+        Args:
+            message: Message containing tool results for conversation history
+        """
+        super().__init__({"callback": {"message": message}})
+
+
+class ForceStopEvent(TypedEvent):
+    """Event emitted when the agent execution is forcibly stopped, either by a tool or by an exception."""
+
+    def __init__(self, reason: str | Exception) -> None:
+        """Initialize with the reason for forced stop.
+
+        Args:
+            reason: String description or exception that caused the forced stop
+        """
+        super().__init__(
+            {
+                "callback": {
+                    "force_stop": True,
+                    "force_stop_reason": str(reason),
+                    # "force_stop_reason_exception": reason if reason and isinstance(reason, Exception) else MISSING,
+                }
+            }
+        )
diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py
new file mode 100644
index 000000000..d63dd97d4
--- /dev/null
+++ b/tests/strands/agent/hooks/test_agent_events.py
@@ -0,0 +1,159 @@
+import unittest.mock
+from unittest.mock import ANY, MagicMock, call
+
+import pytest
+
+import strands
+from strands import Agent
+from strands.agent import AgentResult
+from strands.types._events import TypedEvent
+from strands.types.exceptions import ModelThrottledException
+from tests.fixtures.mocked_model_provider import MockedModelProvider
+
+
+@pytest.fixture
+def mock_time():
+    with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock:
+        yield mock
+
+
+@pytest.mark.asyncio
+async def test_stream_async_e2e(alist, mock_time):
+    @strands.tool
+    def fake_tool(agent: Agent):
+        return "Done!"
+ + mock_provider = MockedModelProvider( + [ + {"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}, + {"role": "assistant", "content": [{"text": "Okay invoking tool!"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"name": "fake_tool", "toolUseId": "123", "input": {}}}], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + ) + model = MagicMock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + mock_provider.stream([]), + ] + + mock_callback = unittest.mock.Mock() + agent = Agent(model=model, tools=[fake_tool], callback_handler=mock_callback) + + stream = agent.stream_async("Do the stuff", arg1=1013) + + # Base object with common properties + throttle_props = { + "agent": ANY, + "event_loop_cycle_id": ANY, + "event_loop_cycle_span": ANY, + "event_loop_cycle_trace": ANY, + "arg1": 1013, + "request_state": {}, + } + + tru_events = await alist(stream) + exp_events = [ + {"arg1": 1013, "init_event_loop": True}, + {"start": True}, + {"start_event_loop": True}, + {"event_loop_throttled_delay": 8, **throttle_props}, + {"event_loop_throttled_delay": 16, **throttle_props}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"redactContent": {"redactUserContentMessage": "BLOCKED!"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "INPUT BLOCKED!"}}}}, + { + "agent": ANY, + "arg1": 1013, + "data": "INPUT BLOCKED!", + "delta": {"text": "INPUT BLOCKED!"}, + "event_loop_cycle_id": ANY, + "event_loop_cycle_span": ANY, + "event_loop_cycle_trace": ANY, + "request_state": {}, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}}, + {"message": {"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}}, + { + "result": AgentResult( + stop_reason="guardrail_intervened", + message={"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}, + metrics=ANY, + state={}, + ), + }, + ] + assert tru_events == exp_events + + exp_calls = [call(**event) for event in exp_events] + act_calls = mock_callback.call_args_list + assert act_calls == exp_calls + + # Ensure that all events coming out of the agent are *not* typed events + typed_events = [event for event in tru_events if isinstance(event, TypedEvent)] + assert typed_events == [] + + +@pytest.mark.asyncio +async def test_event_loop_cycle_text_response_throttling_early_end( + agenerator, + alist, + mock_time, +): + model = MagicMock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ] + + mock_callback = unittest.mock.Mock() + with pytest.raises(ModelThrottledException): + agent = Agent(model=model, callback_handler=mock_callback) + + # Because we're throwing an exception, we manually collect the items here + tru_events = [] + stream = agent.stream_async("Do the stuff", arg1=1013) + async for event in stream: + tru_events.append(event) + + # Base object with common properties + common_props = { + "agent": ANY, + "event_loop_cycle_id": ANY, + "event_loop_cycle_span": 
ANY, + "event_loop_cycle_trace": ANY, + "arg1": 1013, + "request_state": {}, + } + + exp_events = [ + {"init_event_loop": True, "arg1": 1013}, + {"start": True}, + {"start_event_loop": True}, + {"event_loop_throttled_delay": 8, **common_props}, + {"event_loop_throttled_delay": 16, **common_props}, + {"event_loop_throttled_delay": 32, **common_props}, + {"event_loop_throttled_delay": 64, **common_props}, + {"event_loop_throttled_delay": 128, **common_props}, + {"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"}, + ] + + assert tru_events == exp_events + + exp_calls = [call(**event) for event in exp_events] + act_calls = mock_callback.call_args_list + assert act_calls == exp_calls + + # Ensure that all events coming out of the agent are *not* typed events + typed_events = [event for event in tru_events if isinstance(event, TypedEvent)] + assert typed_events == [] diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 67ea5940a..a4a8af09a 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -668,62 +668,71 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): ) agent("test") - callback_handler.assert_has_calls( - [ - unittest.mock.call(init_event_loop=True), - unittest.mock.call(start=True), - unittest.mock.call(start_event_loop=True), - unittest.mock.call( - event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}} - ), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), - unittest.mock.call( - agent=agent, - current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, - delta={"toolUse": {"input": '{"value"}'}}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - request_state={}, - ), - unittest.mock.call(event={"contentBlockStop": {}}), - unittest.mock.call(event={"contentBlockStart": {"start": {}}}), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), - unittest.mock.call( - agent=agent, - delta={"reasoningContent": {"text": "value"}}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - reasoning=True, - reasoningText="value", - request_state={}, - ), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), - unittest.mock.call( - agent=agent, - delta={"reasoningContent": {"signature": "value"}}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - reasoning=True, - reasoning_signature="value", - request_state={}, - ), - unittest.mock.call(event={"contentBlockStop": {}}), - unittest.mock.call(event={"contentBlockStart": {"start": {}}}), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), - unittest.mock.call( - agent=agent, - data="value", - delta={"text": "value"}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - request_state={}, - ), - unittest.mock.call(event={"contentBlockStop": {}}), - unittest.mock.call( + assert callback_handler.call_args_list == [ + unittest.mock.call(init_event_loop=True), + unittest.mock.call(start=True), + unittest.mock.call(start_event_loop=True), + unittest.mock.call(event={"contentBlockStart": {"start": 
{"toolUse": {"toolUseId": "123", "name": "test"}}}}), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), + unittest.mock.call( + agent=agent, + current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, + delta={"toolUse": {"input": '{"value"}'}}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + request_state={}, + ), + unittest.mock.call(event={"contentBlockStop": {}}), + unittest.mock.call(event={"contentBlockStart": {"start": {}}}), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), + unittest.mock.call( + agent=agent, + delta={"reasoningContent": {"text": "value"}}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + reasoning=True, + reasoningText="value", + request_state={}, + ), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), + unittest.mock.call( + agent=agent, + delta={"reasoningContent": {"signature": "value"}}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + reasoning=True, + reasoning_signature="value", + request_state={}, + ), + unittest.mock.call(event={"contentBlockStop": {}}), + unittest.mock.call(event={"contentBlockStart": {"start": {}}}), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), + unittest.mock.call( + agent=agent, + data="value", + delta={"text": "value"}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + request_state={}, + ), + unittest.mock.call(event={"contentBlockStop": {}}), + unittest.mock.call( + message={ + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, + {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, + {"text": "value"}, + ], + }, + ), + unittest.mock.call( + result=AgentResult( + stop_reason="end_turn", message={ "role": "assistant", "content": [ @@ -732,9 +741,11 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): {"text": "value"}, ], }, - ), - ], - ) + metrics=unittest.mock.ANY, + state={}, + ) + ), + ] @pytest.mark.asyncio diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 7760c498a..fd9548dae 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -4,6 +4,7 @@ import strands import strands.event_loop +from strands.types._events import TypedEvent from strands.types.streaming import ( ContentBlockDeltaEvent, ContentBlockStartEvent, @@ -562,6 +563,10 @@ async def test_process_stream(response, exp_events, agenerator, alist): tru_events = await alist(stream) assert tru_events == exp_events + # Ensure that we're getting typed events coming out of process_stream + non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)] + assert non_typed_events == [] + @pytest.mark.asyncio async def test_stream_messages(agenerator, alist): @@ -624,3 +629,7 @@ async def test_stream_messages(agenerator, alist): None, "test prompt", ) + + # Ensure that we're getting typed events coming out of process_stream + non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)] + assert 
non_typed_events == [] From d9f8d8a76c80eb5296b4a60f778d62192241c128 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 28 Aug 2025 09:47:17 -0400 Subject: [PATCH 058/104] summarization manager - add summary prompt to messages (#698) * summarization manager - add summary prompt to messages * summarize conversation - assistant to user role * fix test * add period --- .../conversation_manager/summarizing_conversation_manager.py | 5 ++--- tests/strands/agent/test_summarizing_conversation_manager.py | 4 ++-- .../test_summarizing_conversation_manager_integration.py | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 60e832215..b08b6853e 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -1,7 +1,7 @@ """Summarizing conversation history management with configurable options.""" import logging -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, cast from typing_extensions import override @@ -201,8 +201,7 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: # Use the agent to generate summary with rich content (can use tools if needed) result = summarization_agent("Please summarize this conversation.") - - return result.message + return cast(Message, {**result.message, "role": "user"}) finally: # Restore original agent state diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py index a97104412..6003a1710 100644 --- a/tests/strands/agent/test_summarizing_conversation_manager.py +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -99,7 +99,7 @@ def test_reduce_context_with_summarization(summarizing_manager, mock_agent): assert len(mock_agent.messages) == 4 # First message should be the summary - assert mock_agent.messages[0]["role"] == "assistant" + assert mock_agent.messages[0]["role"] == "user" first_content = mock_agent.messages[0]["content"][0] assert "text" in first_content and "This is a summary of the conversation." in first_content["text"] @@ -438,7 +438,7 @@ def test_reduce_context_tool_pair_adjustment_works_with_forward_search(): assert len(mock_agent.messages) == 2 # First message should be the summary - assert mock_agent.messages[0]["role"] == "assistant" + assert mock_agent.messages[0]["role"] == "user" summary_content = mock_agent.messages[0]["content"][0] assert "text" in summary_content and "This is a summary of the conversation." 
in summary_content["text"] diff --git a/tests_integ/test_summarizing_conversation_manager_integration.py b/tests_integ/test_summarizing_conversation_manager_integration.py index 719520b8d..b205c723f 100644 --- a/tests_integ/test_summarizing_conversation_manager_integration.py +++ b/tests_integ/test_summarizing_conversation_manager_integration.py @@ -160,7 +160,7 @@ def test_summarization_with_context_overflow(model): # First message should be the summary (assistant message) summary_message = agent.messages[0] - assert summary_message["role"] == "assistant" + assert summary_message["role"] == "user" assert len(summary_message["content"]) > 0 # Verify the summary contains actual text content @@ -362,7 +362,7 @@ def test_dedicated_summarization_agent(model, summarization_model): # Get the summary message summary_message = agent.messages[0] - assert summary_message["role"] == "assistant" + assert summary_message["role"] == "user" # Extract summary text summary_text = None From 6dadbce85bbfef200bf3283810597895aa7ad2dc Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 28 Aug 2025 10:09:30 -0400 Subject: [PATCH 059/104] feat: Use TypedEvent inheritance for callback behavior (#755) Move away from "callback" nested properties in the dict and explicitly passing invocation_state migrating to behaviors on the TypedEvent: - TypedEvent.is_callback_event for determining if an event should be yielded and or invoked in the callback - TypedEvent.prepare for taking in invocation_state Customers still only get dictionaries, as we decided that this will remain an implementation detail for the time being, but this makes the events typed all the way up until *just* before we yield events back to the caller --------- Co-authored-by: Mackenzie Zastrow --- src/strands/agent/agent.py | 31 +-- src/strands/event_loop/event_loop.py | 22 +- src/strands/tools/executors/_executor.py | 23 +- src/strands/tools/executors/concurrent.py | 7 +- src/strands/tools/executors/sequential.py | 7 +- src/strands/types/_events.py | 145 +++++++++++-- tests/strands/agent/test_agent.py | 15 +- tests/strands/event_loop/test_event_loop.py | 2 +- tests/strands/event_loop/test_streaming.py | 196 +++++++----------- .../tools/executors/test_concurrent.py | 16 +- .../strands/tools/executors/test_executor.py | 28 +-- .../tools/executors/test_sequential.py | 11 +- 12 files changed, 288 insertions(+), 215 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8233c4bfe..1e64f5adb 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -50,7 +50,7 @@ from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher -from ..types._events import InitEventLoopEvent +from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException @@ -576,13 +576,16 @@ async def stream_async( events = self._run_loop(messages, invocation_state=kwargs) async for event in events: - if "callback" in event: - callback_handler(**event["callback"]) - yield event["callback"] + event.prepare(invocation_state=kwargs) + + if event.is_callback_event: + as_dict = event.as_dict() + callback_handler(**as_dict) + yield as_dict result = AgentResult(*event["stop"]) callback_handler(result=result) - yield {"result": result} + 
yield AgentResultEvent(result=result).as_dict() self._end_agent_trace_span(response=result) @@ -590,9 +593,7 @@ async def stream_async( self._end_agent_trace_span(error=e) raise - async def _run_loop( - self, messages: Messages, invocation_state: dict[str, Any] - ) -> AsyncGenerator[dict[str, Any], None]: + async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. Args: @@ -605,7 +606,7 @@ async def _run_loop( self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) try: - yield InitEventLoopEvent(invocation_state) + yield InitEventLoopEvent() for message in messages: self._append_message(message) @@ -616,13 +617,13 @@ async def _run_loop( # Signal from the model provider that the message sent by the user should be redacted, # likely due to a guardrail. if ( - event.get("callback") - and event["callback"].get("event") - and event["callback"]["event"].get("redactContent") - and event["callback"]["event"]["redactContent"].get("redactUserContentMessage") + isinstance(event, ModelStreamChunkEvent) + and event.chunk + and event.chunk.get("redactContent") + and event.chunk["redactContent"].get("redactUserContentMessage") ): self.messages[-1]["content"] = [ - {"text": event["callback"]["event"]["redactContent"]["redactUserContentMessage"]} + {"text": str(event.chunk["redactContent"]["redactUserContentMessage"])} ] if self._session_manager: self._session_manager.redact_latest_message(self.messages[-1], self) @@ -632,7 +633,7 @@ async def _run_loop( self.conversation_manager.apply_management(self) self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) - async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: + async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index a166902eb..a99ecc8a6 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -30,9 +30,11 @@ EventLoopThrottleEvent, ForceStopEvent, ModelMessageEvent, + ModelStopReason, StartEvent, StartEventLoopEvent, ToolResultMessageEvent, + TypedEvent, ) from ..types.content import Message from ..types.exceptions import ( @@ -56,7 +58,7 @@ MAX_DELAY = 240 # 4 minutes -async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: +async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute a single cycle of the event loop. This core function processes a single conversation turn, handling model inference, tool execution, and error @@ -139,17 +141,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) try: - # TODO: To maintain backwards compatibility, we need to combine the stream event with invocation_state - # before yielding to the callback handler. This will be revisited when migrating to strongly - # typed events. 
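# A minimal sketch of the TypedEvent contract this change moves to, assuming the
# classes are importable from strands.types._events: events stay plain dicts,
# prepare() merges the caller's invocation_state just before emission, and
# is_callback_event gates whether the callback_handler receives the event.
from strands.types._events import EventLoopThrottleEvent

throttle = EventLoopThrottleEvent(delay=8)
throttle.prepare(invocation_state={"arg1": 1013, "request_state": {}})
assert throttle.is_callback_event
assert throttle.as_dict() == {
    "event_loop_throttled_delay": 8,
    "arg1": 1013,
    "request_state": {},
}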
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): - if "callback" in event: - yield { - "callback": { - **event["callback"], - **(invocation_state if "delta" in event["callback"] else {}), - } - } + if not isinstance(event, ModelStopReason): + yield event stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -198,7 +192,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> time.sleep(current_delay) current_delay = min(current_delay * 2, MAX_DELAY) - yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state) + yield EventLoopThrottleEvent(delay=current_delay) else: raise e @@ -280,7 +274,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) -async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: +async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Make a recursive call to event_loop_cycle with the current state. This function is used when the event loop needs to continue processing after tool execution. @@ -321,7 +315,7 @@ async def _handle_tool_execution( cycle_span: Any, cycle_start_time: float, invocation_state: dict[str, Any], -) -> AsyncGenerator[dict[str, Any], None]: +) -> AsyncGenerator[TypedEvent, None]: """Handles the execution of tools requested by the model during an event loop cycle. Args: diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 9999b77fc..701a3bac0 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -7,15 +7,16 @@ import abc import logging import time -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, cast from opentelemetry import trace as trace_api from ...experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent from ...telemetry.metrics import Trace from ...telemetry.tracer import get_tracer +from ...types._events import ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message -from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse +from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse if TYPE_CHECKING: # pragma: no cover from ...agent import Agent @@ -33,7 +34,7 @@ async def _stream( tool_results: list[ToolResult], invocation_state: dict[str, Any], **kwargs: Any, - ) -> ToolGenerator: + ) -> AsyncGenerator[TypedEvent, None]: """Stream tool events. 
This method adds additional logic to the stream invocation including: @@ -113,12 +114,12 @@ async def _stream( result=result, ) ) - yield after_event.result + yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) return async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): - yield event + yield ToolStreamEvent(tool_use, event) result = cast(ToolResult, event) @@ -131,7 +132,8 @@ async def _stream( result=result, ) ) - yield after_event.result + + yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) except Exception as e: @@ -151,7 +153,7 @@ async def _stream( exception=e, ) ) - yield after_event.result + yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) @staticmethod @@ -163,7 +165,7 @@ async def _stream_with_trace( cycle_span: Any, invocation_state: dict[str, Any], **kwargs: Any, - ) -> ToolGenerator: + ) -> AsyncGenerator[TypedEvent, None]: """Execute tool with tracing and metrics collection. Args: @@ -190,7 +192,8 @@ async def _stream_with_trace( async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs): yield event - result = cast(ToolResult, event) + result_event = cast(ToolResultEvent, event) + result = result_event.tool_result tool_success = result.get("status") == "success" tool_duration = time.time() - tool_start_time @@ -210,7 +213,7 @@ def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], - ) -> ToolGenerator: + ) -> AsyncGenerator[TypedEvent, None]: """Execute the given tools according to this executor's strategy. Args: diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 7d5dd7fe7..767071bae 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -1,12 +1,13 @@ """Concurrent tool executor implementation.""" import asyncio -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, AsyncGenerator from typing_extensions import override from ...telemetry.metrics import Trace -from ...types.tools import ToolGenerator, ToolResult, ToolUse +from ...types._events import TypedEvent +from ...types.tools import ToolResult, ToolUse from ._executor import ToolExecutor if TYPE_CHECKING: # pragma: no cover @@ -25,7 +26,7 @@ async def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], - ) -> ToolGenerator: + ) -> AsyncGenerator[TypedEvent, None]: """Execute tools concurrently. Args: diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 55b26f6d3..60e5c7fa7 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -1,11 +1,12 @@ """Sequential tool executor implementation.""" -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, AsyncGenerator from typing_extensions import override from ...telemetry.metrics import Trace -from ...types.tools import ToolGenerator, ToolResult, ToolUse +from ...types._events import TypedEvent +from ...types.tools import ToolResult, ToolUse from ._executor import ToolExecutor if TYPE_CHECKING: # pragma: no cover @@ -24,7 +25,7 @@ async def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], - ) -> ToolGenerator: + ) -> AsyncGenerator[TypedEvent, None]: """Execute tools sequentially. 
Args: diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 1bddc5877..cc2330a81 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,15 +5,18 @@ agent lifecycle. """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast + +from typing_extensions import override from ..telemetry import EventLoopMetrics from .content import Message from .event_loop import Metrics, StopReason, Usage from .streaming import ContentBlockDelta, StreamEvent +from .tools import ToolResult, ToolUse if TYPE_CHECKING: - pass + from ..agent import AgentResult class TypedEvent(dict): @@ -27,6 +30,23 @@ def __init__(self, data: dict[str, Any] | None = None) -> None: """ super().__init__(data or {}) + @property + def is_callback_event(self) -> bool: + """True if this event should trigger the callback_handler to fire.""" + return True + + def as_dict(self) -> dict: + """Convert this event to a raw dictionary for emitting purposes.""" + return {**self} + + def prepare(self, invocation_state: dict) -> None: + """Prepare the event for emission by adding invocation state. + + This allows a subset of events to merge with the invocation_state without needing to + pass around the invocation_state throughout the system. + """ + ... + class InitEventLoopEvent(TypedEvent): """Event emitted at the very beginning of agent execution. @@ -38,9 +58,13 @@ class InitEventLoopEvent(TypedEvent): invocation_state: The invocation state passed into the request """ - def __init__(self, invocation_state: dict) -> None: + def __init__(self) -> None: """Initialize the event loop initialization event.""" - super().__init__({"callback": {"init_event_loop": True, **invocation_state}}) + super().__init__({"init_event_loop": True}) + + @override + def prepare(self, invocation_state: dict) -> None: + self.update(invocation_state) class StartEvent(TypedEvent): @@ -55,7 +79,7 @@ class StartEvent(TypedEvent): def __init__(self) -> None: """Initialize the event loop start event.""" - super().__init__({"callback": {"start": True}}) + super().__init__({"start": True}) class StartEventLoopEvent(TypedEvent): @@ -67,7 +91,7 @@ class StartEventLoopEvent(TypedEvent): def __init__(self) -> None: """Initialize the event loop processing start event.""" - super().__init__({"callback": {"start_event_loop": True}}) + super().__init__({"start_event_loop": True}) class ModelStreamChunkEvent(TypedEvent): @@ -79,7 +103,11 @@ def __init__(self, chunk: StreamEvent) -> None: Args: chunk: Incremental streaming data from the model response """ - super().__init__({"callback": {"event": chunk}}) + super().__init__({"event": chunk}) + + @property + def chunk(self) -> StreamEvent: + return cast(StreamEvent, self.get("event")) class ModelStreamEvent(TypedEvent): @@ -97,13 +125,23 @@ def __init__(self, delta_data: dict[str, Any]) -> None: """ super().__init__(delta_data) + @property + def is_callback_event(self) -> bool: + # Only invoke a callback if we're non-empty + return len(self.keys()) > 0 + + @override + def prepare(self, invocation_state: dict) -> None: + if "delta" in self: + self.update(invocation_state) + class ToolUseStreamEvent(ModelStreamEvent): """Event emitted during tool use input streaming.""" def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None: """Initialize with delta and current tool use state.""" - super().__init__({"callback": {"delta": delta, "current_tool_use": current_tool_use}}) + super().__init__({"delta": delta, "current_tool_use": 
current_tool_use}) class TextStreamEvent(ModelStreamEvent): @@ -111,7 +149,7 @@ class TextStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, text: str) -> None: """Initialize with delta and text content.""" - super().__init__({"callback": {"data": text, "delta": delta}}) + super().__init__({"data": text, "delta": delta}) class ReasoningTextStreamEvent(ModelStreamEvent): @@ -119,7 +157,7 @@ class ReasoningTextStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, reasoning_text: str | None) -> None: """Initialize with delta and reasoning text.""" - super().__init__({"callback": {"reasoningText": reasoning_text, "delta": delta, "reasoning": True}}) + super().__init__({"reasoningText": reasoning_text, "delta": delta, "reasoning": True}) class ReasoningSignatureStreamEvent(ModelStreamEvent): @@ -127,7 +165,7 @@ class ReasoningSignatureStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, reasoning_signature: str | None) -> None: """Initialize with delta and reasoning signature.""" - super().__init__({"callback": {"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True}}) + super().__init__({"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True}) class ModelStopReason(TypedEvent): @@ -150,6 +188,11 @@ def __init__( """ super().__init__({"stop": (stop_reason, message, usage, metrics)}) + @property + @override + def is_callback_event(self) -> bool: + return False + class EventLoopStopEvent(TypedEvent): """Event emitted when the agent execution completes normally.""" @@ -171,18 +214,76 @@ def __init__( """ super().__init__({"stop": (stop_reason, message, metrics, request_state)}) + @property + @override + def is_callback_event(self) -> bool: + return False + class EventLoopThrottleEvent(TypedEvent): """Event emitted when the event loop is throttled due to rate limiting.""" - def __init__(self, delay: int, invocation_state: dict[str, Any]) -> None: + def __init__(self, delay: int) -> None: """Initialize with the throttle delay duration. Args: delay: Delay in seconds before the next retry attempt - invocation_state: The invocation state passed into the request """ - super().__init__({"callback": {"event_loop_throttled_delay": delay, **invocation_state}}) + super().__init__({"event_loop_throttled_delay": delay}) + + @override + def prepare(self, invocation_state: dict) -> None: + self.update(invocation_state) + + +class ToolResultEvent(TypedEvent): + """Event emitted when a tool execution completes.""" + + def __init__(self, tool_result: ToolResult) -> None: + """Initialize with the completed tool result. + + Args: + tool_result: Final result from the tool execution + """ + super().__init__({"tool_result": tool_result}) + + @property + def tool_use_id(self) -> str: + """The toolUseId associated with this result.""" + return cast(str, cast(ToolResult, self.get("tool_result")).get("toolUseId")) + + @property + def tool_result(self) -> ToolResult: + """Final result from the completed tool execution.""" + return cast(ToolResult, self.get("tool_result")) + + @property + @override + def is_callback_event(self) -> bool: + return False + + +class ToolStreamEvent(TypedEvent): + """Event emitted when a tool yields sub-events as part of tool execution.""" + + def __init__(self, tool_use: ToolUse, tool_sub_event: Any) -> None: + """Initialize with tool streaming data. 
+ + Args: + tool_use: The tool invocation producing the stream + tool_sub_event: The yielded event from the tool execution + """ + super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_event": tool_sub_event}) + + @property + def tool_use_id(self) -> str: + """The toolUseId associated with this stream.""" + return cast(str, cast(ToolUse, self.get("tool_stream_tool_use")).get("toolUseId")) + + @property + @override + def is_callback_event(self) -> bool: + return False class ModelMessageEvent(TypedEvent): @@ -198,7 +299,7 @@ def __init__(self, message: Message) -> None: Args: message: The response message from the model """ - super().__init__({"callback": {"message": message}}) + super().__init__({"message": message}) class ToolResultMessageEvent(TypedEvent): @@ -215,7 +316,7 @@ def __init__(self, message: Any) -> None: Args: message: Message containing tool results for conversation history """ - super().__init__({"callback": {"message": message}}) + super().__init__({"message": message}) class ForceStopEvent(TypedEvent): @@ -229,10 +330,12 @@ def __init__(self, reason: str | Exception) -> None: """ super().__init__( { - "callback": { - "force_stop": True, - "force_stop_reason": str(reason), - # "force_stop_reason_exception": reason if reason and isinstance(reason, Exception) else MISSING, - } + "force_stop": True, + "force_stop_reason": str(reason), } ) + + +class AgentResultEvent(TypedEvent): + def __init__(self, result: "AgentResult"): + super().__init__({"result": result}) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index a4a8af09a..a8561abe4 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -19,6 +19,7 @@ from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize +from strands.types._events import EventLoopStopEvent, ModelStreamEvent from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType @@ -406,7 +407,7 @@ async def check_invocation_state(**kwargs): assert invocation_state["agent"] == agent # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) mock_event_loop_cycle.side_effect = check_invocation_state @@ -1144,12 +1145,12 @@ async def test_stream_async_returns_all_events(mock_event_loop_cycle, alist): # Define the side effect to simulate callback handler being called multiple times async def test_event_loop(*args, **kwargs): - yield {"callback": {"data": "First chunk"}} - yield {"callback": {"data": "Second chunk"}} - yield {"callback": {"data": "Final chunk", "complete": True}} + yield ModelStreamEvent({"data": "First chunk"}) + yield ModelStreamEvent({"data": "Second chunk"}) + yield ModelStreamEvent({"data": "Final chunk", "complete": True}) # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) mock_event_loop_cycle.side_effect = test_event_loop mock_callback = unittest.mock.Mock() @@ -1234,7 +1235,7 @@ async def 
check_invocation_state(**kwargs): invocation_state = kwargs["invocation_state"] assert invocation_state["some_value"] == "a_value" # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) mock_event_loop_cycle.side_effect = check_invocation_state @@ -1366,7 +1367,7 @@ async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_trac mock_get_tracer.return_value = mock_tracer async def test_event_loop(*args, **kwargs): - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {})} + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {}) mock_event_loop_cycle.side_effect = test_event_loop diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index c76514ac8..68f9cc5ab 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -486,7 +486,7 @@ async def test_cycle_exception( ] tru_stop_event = None - exp_stop_event = {"callback": {"force_stop": True, "force_stop_reason": "Invalid error presented"}} + exp_stop_event = {"force_stop": True, "force_stop_reason": "Invalid error presented"} with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index fd9548dae..fdd560b22 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -146,7 +146,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ], ) def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args): - exp_callback_event = {"callback": {**callback_args, "delta": event["delta"]}} if callback_args else {} + exp_callback_event = {**callback_args, "delta": event["delta"]} if callback_args else {} tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state) @@ -316,85 +316,71 @@ def test_extract_usage_metrics_with_cache_tokens(): ], [ { - "callback": { - "event": { - "messageStart": { - "role": "assistant", - }, + "event": { + "messageStart": { + "role": "assistant", }, }, }, { - "callback": { - "event": { - "contentBlockStart": { - "start": { - "toolUse": { - "name": "test", - "toolUseId": "123", - }, + "event": { + "contentBlockStart": { + "start": { + "toolUse": { + "name": "test", + "toolUseId": "123", }, }, }, }, }, { - "callback": { - "event": { - "contentBlockDelta": { - "delta": { - "toolUse": { - "input": '{"key": "value"}', - }, + "event": { + "contentBlockDelta": { + "delta": { + "toolUse": { + "input": '{"key": "value"}', }, }, }, }, }, { - "callback": { - "current_tool_use": { - "input": { - "key": "value", - }, - "name": "test", - "toolUseId": "123", + "current_tool_use": { + "input": { + "key": "value", }, - "delta": { - "toolUse": { - "input": '{"key": "value"}', - }, + "name": "test", + "toolUseId": "123", + }, + "delta": { + "toolUse": { + "input": '{"key": "value"}', }, }, }, { - "callback": { - "event": { - "contentBlockStop": {}, - }, + "event": { + "contentBlockStop": {}, }, }, { - "callback": { - "event": { - "messageStop": { - "stopReason": "tool_use", - }, + "event": { + "messageStop": { + "stopReason": "tool_use", }, }, }, { - 
"callback": { - "event": { - "metadata": { - "metrics": { - "latencyMs": 1, - }, - "usage": { - "inputTokens": 1, - "outputTokens": 1, - "totalTokens": 1, - }, + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, }, }, }, @@ -417,9 +403,7 @@ def test_extract_usage_metrics_with_cache_tokens(): [{}], [ { - "callback": { - "event": {}, - }, + "event": {}, }, { "stop": ( @@ -463,80 +447,64 @@ def test_extract_usage_metrics_with_cache_tokens(): ], [ { - "callback": { - "event": { - "messageStart": { - "role": "assistant", - }, + "event": { + "messageStart": { + "role": "assistant", }, }, }, { - "callback": { - "event": { - "contentBlockStart": { - "start": {}, - }, + "event": { + "contentBlockStart": { + "start": {}, }, }, }, { - "callback": { - "event": { - "contentBlockDelta": { - "delta": { - "text": "Hello!", - }, + "event": { + "contentBlockDelta": { + "delta": { + "text": "Hello!", }, }, }, }, { - "callback": { - "data": "Hello!", - "delta": { - "text": "Hello!", - }, + "data": "Hello!", + "delta": { + "text": "Hello!", }, }, { - "callback": { - "event": { - "contentBlockStop": {}, - }, + "event": { + "contentBlockStop": {}, }, }, { - "callback": { - "event": { - "messageStop": { - "stopReason": "guardrail_intervened", - }, + "event": { + "messageStop": { + "stopReason": "guardrail_intervened", }, }, }, { - "callback": { - "event": { - "redactContent": { - "redactAssistantContentMessage": "REDACTED.", - "redactUserContentMessage": "REDACTED", - }, + "event": { + "redactContent": { + "redactAssistantContentMessage": "REDACTED.", + "redactUserContentMessage": "REDACTED", }, }, }, { - "callback": { - "event": { - "metadata": { - "metrics": { - "latencyMs": 1, - }, - "usage": { - "inputTokens": 1, - "outputTokens": 1, - "totalTokens": 1, - }, + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, }, }, }, @@ -588,29 +556,23 @@ async def test_stream_messages(agenerator, alist): tru_events = await alist(stream) exp_events = [ { - "callback": { - "event": { - "contentBlockDelta": { - "delta": { - "text": "test", - }, + "event": { + "contentBlockDelta": { + "delta": { + "text": "test", }, }, }, }, { - "callback": { - "data": "test", - "delta": { - "text": "test", - }, + "data": "test", + "delta": { + "text": "test", }, }, { - "callback": { - "event": { - "contentBlockStop": {}, - }, + "event": { + "contentBlockStop": {}, }, }, { diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index 7e0d6c2df..140537add 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,6 +1,8 @@ import pytest from strands.tools.executors import ConcurrentToolExecutor +from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types.tools import ToolUse @pytest.fixture @@ -12,21 +14,21 @@ def executor(): async def test_concurrent_executor_execute( executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist ): - tool_uses = [ + tool_uses: list[ToolUse] = [ {"name": "weather_tool", "toolUseId": "1", "input": {}}, {"name": "temperature_tool", "toolUseId": "2", "input": {}}, ] stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) - tru_events = sorted(await alist(stream), key=lambda event: event.get("toolUseId")) + tru_events = sorted(await 
alist(stream), key=lambda event: event.tool_use_id) exp_events = [ - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, - {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), + ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) - exp_results = [exp_events[1], exp_events[3]] + exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index edbad3939..56caa950a 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -6,6 +6,8 @@ from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent from strands.telemetry.metrics import Trace from strands.tools.executors._executor import ToolExecutor +from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types.tools import ToolUse @pytest.fixture @@ -32,18 +34,18 @@ def tracer(): async def test_executor_stream_yields_result( executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist ): - tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}} + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} stream = executor._stream(agent, tool_use, tool_results, invocation_state) tru_events = await alist(stream) exp_events = [ - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[-1]] + exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results tru_hook_events = hook_events @@ -73,11 +75,11 @@ async def test_executor_stream_yields_tool_error( stream = executor._stream(agent, tool_use, tool_results, invocation_state) tru_events = await alist(stream) - exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Error: Tool error"}]}] + exp_events = [ToolResultEvent({"toolUseId": "1", "status": "error", "content": [{"text": "Error: Tool error"}]})] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[-1]] + exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results tru_hook_after_event = hook_events[-1] @@ -98,11 +100,13 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results stream = executor._stream(agent, tool_use, tool_results, invocation_state) tru_events = await alist(stream) - exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Unknown tool: unknown_tool"}]}] + exp_events = [ + 
ToolResultEvent({"toolUseId": "1", "status": "error", "content": [{"text": "Unknown tool: unknown_tool"}]}) + ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[-1]] + exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results tru_hook_after_event = hook_events[-1] @@ -120,18 +124,18 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results async def test_executor_stream_with_trace( executor, tracer, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist ): - tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}} + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} stream = executor._stream_with_trace(agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state) tru_events = await alist(stream) exp_events = [ - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[-1]] + exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results tracer.start_tool_call_span.assert_called_once_with(tool_use, cycle_span) diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index d9b32c129..d4e98223e 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,6 +1,7 @@ import pytest from strands.tools.executors import SequentialToolExecutor +from strands.types._events import ToolResultEvent, ToolStreamEvent @pytest.fixture @@ -20,13 +21,13 @@ async def test_sequential_executor_execute( tru_events = await alist(stream) exp_events = [ - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, - {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), + ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[1], exp_events[2]] + exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] assert tru_results == exp_results From 47faba0911f00cecaff5cee8145530818a65c5e7 Mon Sep 17 00:00:00 2001 From: Laith Al-Saadoon <9553966+theagenticguy@users.noreply.github.com> Date: Thu, 28 Aug 2025 08:29:20 -0700 Subject: [PATCH 060/104] feat: claude citation support with BedrockModel (#631) * feat: add citations to document content * feat: addes citation types * chore: remove uv.lock * test: add letter.pdf for test-integ * feat: working bedrock citations feature * feat: fail early for citations with incompatible models * fix: validates model ids with cross region inference ids * Apply suggestion from @Unshure Co-authored-by: Nick Clegg * fix: addresses comments * removes client 
exception handling * moves citation into text elif * puts relative imports back * fix: tests failing * Update src/strands/models/bedrock.py Removes old comment Co-authored-by: Nick Clegg * Update src/strands/models/bedrock.py Removes old comment Co-authored-by: Nick Clegg * Update imports in bedrock.py Refactor imports in bedrock.py to include CitationsDelta. * feat: typed citation events --------- Co-authored-by: Nick Clegg --- src/strands/agent/agent_result.py | 1 - src/strands/event_loop/streaming.py | 16 +++ src/strands/models/bedrock.py | 29 +++- src/strands/types/_events.py | 9 ++ src/strands/types/citations.py | 152 +++++++++++++++++++++ src/strands/types/content.py | 3 + src/strands/types/media.py | 8 +- src/strands/types/streaming.py | 37 +++++ tests/strands/event_loop/test_streaming.py | 29 ++++ tests_integ/conftest.py | 7 + tests_integ/letter.pdf | Bin 0 -> 100738 bytes tests_integ/models/test_model_bedrock.py | 49 ++++++- tests_integ/test_max_tokens_reached.py | 2 +- 13 files changed, 332 insertions(+), 10 deletions(-) create mode 100644 src/strands/types/citations.py create mode 100644 tests_integ/letter.pdf diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index e28e1c5b8..f3758c8d2 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -42,5 +42,4 @@ def __str__(self) -> str: for item in content_array: if isinstance(item, dict) and "text" in item: result += item.get("text", "") + "\n" - return result diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 7507c6d75..efe094e5f 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -6,6 +6,7 @@ from ..models.model import Model from ..types._events import ( + CitationStreamEvent, ModelStopReason, ModelStreamChunkEvent, ModelStreamEvent, @@ -15,6 +16,7 @@ ToolUseStreamEvent, TypedEvent, ) +from ..types.citations import CitationsContentBlock from ..types.content import ContentBlock, Message, Messages from ..types.streaming import ( ContentBlockDeltaEvent, @@ -140,6 +142,13 @@ def handle_content_block_delta( state["text"] += delta_content["text"] typed_event = TextStreamEvent(text=delta_content["text"], delta=delta_content) + elif "citation" in delta_content: + if "citationsContent" not in state: + state["citationsContent"] = [] + + state["citationsContent"].append(delta_content["citation"]) + typed_event = CitationStreamEvent(delta=delta_content, citation=delta_content["citation"]) + elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: if "reasoningText" not in state: @@ -178,6 +187,7 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: current_tool_use = state["current_tool_use"] text = state["text"] reasoning_text = state["reasoningText"] + citations_content = state["citationsContent"] if current_tool_use: if "input" not in current_tool_use: @@ -202,6 +212,10 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: elif text: content.append({"text": text}) state["text"] = "" + if citations_content: + citations_block: CitationsContentBlock = {"citations": citations_content} + content.append({"citationsContent": citations_block}) + state["citationsContent"] = [] elif reasoning_text: content_block: ContentBlock = { @@ -275,6 +289,8 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[T "text": "", "current_tool_use": {}, "reasoningText": "", + "signature": "", + "citationsContent": [], 
} state["content"] = state["message"]["content"] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index ace35640a..0fe332a47 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -7,7 +7,7 @@ import json import logging import os -from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union +from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -18,8 +18,11 @@ from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec from ..types.content import ContentBlock, Message, Messages -from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException -from ..types.streaming import StreamEvent +from ..types.exceptions import ( + ContextWindowOverflowException, + ModelThrottledException, +) +from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolResult, ToolSpec from .model import Model @@ -510,7 +513,7 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera yield {"messageStart": {"role": response["output"]["message"]["role"]}} # Process content blocks - for content in response["output"]["message"]["content"]: + for content in cast(list[ContentBlock], response["output"]["message"]["content"]): # Yield contentBlockStart event if needed if "toolUse" in content: yield { @@ -553,6 +556,24 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera } } } + elif "citationsContent" in content: + # For non-streaming citations, emit text and metadata deltas in sequence + # to match streaming behavior where they flow naturally + if "content" in content["citationsContent"]: + text_content = "".join([content["text"] for content in content["citationsContent"]["content"]]) + yield { + "contentBlockDelta": {"delta": {"text": text_content}}, + } + + for citation in content["citationsContent"]["citations"]: + # Then emit citation metadata (for structure) + + citation_metadata: CitationsDelta = { + "title": citation["title"], + "location": citation["location"], + "sourceContent": citation["sourceContent"], + } + yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}} # Yield contentBlockStop event yield {"contentBlockStop": {}} diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index cc2330a81..1a7f48d4b 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -10,6 +10,7 @@ from typing_extensions import override from ..telemetry import EventLoopMetrics +from .citations import Citation from .content import Message from .event_loop import Metrics, StopReason, Usage from .streaming import ContentBlockDelta, StreamEvent @@ -152,6 +153,14 @@ def __init__(self, delta: ContentBlockDelta, text: str) -> None: super().__init__({"data": text, "delta": delta}) +class CitationStreamEvent(ModelStreamEvent): + """Event emitted during citation streaming.""" + + def __init__(self, delta: ContentBlockDelta, citation: Citation) -> None: + """Initialize with delta and citation content.""" + super().__init__({"callback": {"citation": citation, "delta": delta}}) + + class ReasoningTextStreamEvent(ModelStreamEvent): """Event emitted during reasoning text streaming.""" diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py new file mode 100644 index 000000000..b0e28f655 --- /dev/null +++ 
b/src/strands/types/citations.py @@ -0,0 +1,152 @@ +"""Citation type definitions for the SDK. + +These types are modeled after the Bedrock API. +""" + +from typing import List, Union + +from typing_extensions import TypedDict + + +class CitationsConfig(TypedDict): + """Configuration for enabling citations on documents. + + Attributes: + enabled: Whether citations are enabled for this document. + """ + + enabled: bool + + +class DocumentCharLocation(TypedDict, total=False): + """Specifies a character-level location within a document. + + Provides precise positioning information for cited content using + start and end character indices. + + Attributes: + documentIndex: The index of the document within the array of documents + provided in the request. Minimum value of 0. + start: The starting character position of the cited content within + the document. Minimum value of 0. + end: The ending character position of the cited content within + the document. Minimum value of 0. + """ + + documentIndex: int + start: int + end: int + + +class DocumentChunkLocation(TypedDict, total=False): + """Specifies a chunk-level location within a document. + + Provides positioning information for cited content using logical + document segments or chunks. + + Attributes: + documentIndex: The index of the document within the array of documents + provided in the request. Minimum value of 0. + start: The starting chunk identifier or index of the cited content + within the document. Minimum value of 0. + end: The ending chunk identifier or index of the cited content + within the document. Minimum value of 0. + """ + + documentIndex: int + start: int + end: int + + +class DocumentPageLocation(TypedDict, total=False): + """Specifies a page-level location within a document. + + Provides positioning information for cited content using page numbers. + + Attributes: + documentIndex: The index of the document within the array of documents + provided in the request. Minimum value of 0. + start: The starting page number of the cited content within + the document. Minimum value of 0. + end: The ending page number of the cited content within + the document. Minimum value of 0. + """ + + documentIndex: int + start: int + end: int + + +# Union type for citation locations +CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation] + + +class CitationSourceContent(TypedDict, total=False): + """Contains the actual text content from a source document. + + Contains the actual text content from a source document that is being + cited or referenced in the model's response. + + Note: + This is a UNION type, so only one of the members can be specified. + + Attributes: + text: The text content from the source document that is being cited. + """ + + text: str + + +class CitationGeneratedContent(TypedDict, total=False): + """Contains the generated text content that corresponds to a citation. + + Contains the generated text content that corresponds to or is supported + by a citation from a source document. + + Note: + This is a UNION type, so only one of the members can be specified. + + Attributes: + text: The text content that was generated by the model and is + supported by the associated citation. + """ + + text: str + + +class Citation(TypedDict, total=False): + """Contains information about a citation that references a source document. + + Citations provide traceability between the model's generated response + and the source documents that informed that response. 
+ + Attributes: + location: The precise location within the source document where the + cited content can be found, including character positions, page + numbers, or chunk identifiers. + sourceContent: The specific content from the source document that was + referenced or cited in the generated response. + title: The title or identifier of the source document being cited. + """ + + location: CitationLocation + sourceContent: List[CitationSourceContent] + title: str + + +class CitationsContentBlock(TypedDict, total=False): + """A content block containing generated text and associated citations. + + This block type is returned when document citations are enabled, providing + traceability between the generated content and the source documents that + informed the response. + + Attributes: + citations: An array of citations that reference the source documents + used to generate the associated content. + content: The generated content that is supported by the associated + citations. + """ + + citations: List[Citation] + content: List[CitationGeneratedContent] diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 790e9094c..c3eddca4d 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -10,6 +10,7 @@ from typing_extensions import TypedDict +from .citations import CitationsContentBlock from .media import DocumentContent, ImageContent, VideoContent from .tools import ToolResult, ToolUse @@ -83,6 +84,7 @@ class ContentBlock(TypedDict, total=False): toolResult: The result for a tool request that a model makes. toolUse: Information about a tool use request from a model. video: Video to include in the message. + citationsContent: Contains the citations for a document. """ cachePoint: CachePoint @@ -94,6 +96,7 @@ class ContentBlock(TypedDict, total=False): toolResult: ToolResult toolUse: ToolUse video: VideoContent + citationsContent: CitationsContentBlock class SystemContentBlock(TypedDict, total=False): diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 29b89e5c6..69cd60cf3 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -5,10 +5,12 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Literal +from typing import Literal, Optional from typing_extensions import TypedDict +from .citations import CitationsConfig + DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] """Supported document formats.""" @@ -23,7 +25,7 @@ class DocumentSource(TypedDict): bytes: bytes -class DocumentContent(TypedDict): +class DocumentContent(TypedDict, total=False): """A document to include in a message. 
Attributes: @@ -35,6 +37,8 @@ class DocumentContent(TypedDict): format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] name: str source: DocumentSource + citations: Optional[CitationsConfig] + context: Optional[str] ImageFormat = Literal["png", "jpeg", "gif", "webp"] diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index 9c99b2108..dcfd541a8 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -9,6 +9,7 @@ from typing_extensions import TypedDict +from .citations import CitationLocation from .content import ContentBlockStart, Role from .event_loop import Metrics, StopReason, Usage from .guardrails import Trace @@ -57,6 +58,41 @@ class ContentBlockDeltaToolUse(TypedDict): input: str +class CitationSourceContentDelta(TypedDict, total=False): + """Contains incremental updates to source content text during streaming. + + Allows clients to build up the cited content progressively during + streaming responses. + + Attributes: + text: An incremental update to the text content from the source + document that is being cited. + """ + + text: str + + +class CitationsDelta(TypedDict, total=False): + """Contains incremental updates to citation information during streaming. + + This allows clients to build up citation data progressively as the + response is generated. + + Attributes: + location: Specifies the precise location within a source document + where cited content can be found. This can include character-level + positions, page numbers, or document chunks depending on the + document type and indexing method. + sourceContent: The specific content from the source document that was + referenced or cited in the generated response. + title: The title or identifier of the source document being cited. + """ + + location: CitationLocation + sourceContent: list[CitationSourceContentDelta] + title: str + + class ReasoningContentBlockDelta(TypedDict, total=False): """Delta for reasoning content block in a streaming response. 
@@ -83,6 +119,7 @@ class ContentBlockDelta(TypedDict, total=False): reasoningContent: ReasoningContentBlockDelta text: str toolUse: ContentBlockDeltaToolUse + citation: CitationsDelta class ContentBlockDeltaEvent(TypedDict, total=False): diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index fdd560b22..ce12b4e98 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -164,12 +164,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {"toolUseId": "123", "name": "test", "input": '{"key": "value"}'}, "text": "", "reasoningText": "", + "citationsContent": [], }, { "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}], "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], }, ), # Tool Use - Missing input @@ -179,12 +181,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {"toolUseId": "123", "name": "test"}, "text": "", "reasoningText": "", + "citationsContent": [], }, { "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}], "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], }, ), # Text @@ -194,12 +198,31 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {}, "text": "test", "reasoningText": "", + "citationsContent": [], }, { "content": [{"text": "test"}], "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], + }, + ), + # Citations + ( + { + "content": [], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], + }, + { + "content": [], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], }, ), # Reasoning @@ -210,6 +233,7 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "text": "", "reasoningText": "test", "signature": "123", + "citationsContent": [], }, { "content": [{"reasoningContent": {"reasoningText": {"text": "test", "signature": "123"}}}], @@ -217,6 +241,7 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "text": "", "reasoningText": "", "signature": "123", + "citationsContent": [], }, ), # Reasoning without signature @@ -226,12 +251,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {}, "text": "", "reasoningText": "test", + "citationsContent": [], }, { "content": [{"reasoningContent": {"reasoningText": {"text": "test"}}}], "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], }, ), # Empty @@ -241,12 +268,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], }, { "content": [], "current_tool_use": {}, "text": "", "reasoningText": "", + "citationsContent": [], }, ), ], diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index 61c2bf9a1..26453e1f7 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -22,6 +22,13 @@ def yellow_img(pytestconfig): return fp.read() +@pytest.fixture +def letter_pdf(pytestconfig): + path = pytestconfig.rootdir / "tests_integ/letter.pdf" + with open(path, "rb") as fp: + return fp.read() + + 
 ## Async
diff --git a/tests_integ/letter.pdf b/tests_integ/letter.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..d8c59f749219814f9e76c69393de15719d7f37ae
Binary files /dev/null and b/tests_integ/letter.pdf differ
zQ4z;5-7Tn1I1MK~A%9RmrlH0!Q2&TUu9<#ySorOYUJ0stXd8f|;ZI8@)dTHzcxt(K zg1hXc*HcKhp3BkNVX>dwcBN$v1o=GNrcM}| zT|OuRM4w1-YrbVDmw0*bhTOU~HHrED0Kd3aBO3*3xmQm*pT9b;8t2R>U;5t&bOpg5 zbt6CbUQ`cECtYmyntnFh1W;`DhL|Nei9T~M9bCHRee!x_)m>QiFJv#j83{CoiNAgd zG;aIW9p@|7cOe%>)*E1#{d0O=FZo^nbD&kWP=$s7K2NqtqBO7~mt5=D7X53rp0%w&UptOEr5-1Pl}m(4aD zIVKUexUC$^LR`06oJ?ruUYnOtS*wfDyeY1YE{|ebbS&qmz5XDCdjp|gi1U49X>V^f zB7T?8*K6Y5Iq6o{O+i&2@Prv%Z&A{+PM7fNN!K`4-r>#z5K+ZO59_KP<`F8?rQzod zD2e7YG}tG#<~v~zG6J-#*CCk_dqH4H@)dj^0gJRWgE z@Kfe)co+}{wDdYj6ZJeF8a#};m%GQJx{fc>e<{~{ZXm*CIU5NZuPn_l=(AypoZ$LS zBH1WOtS4bf+MzsWrOtT?MlCCt#C~0nUIOWT-7^b(cVD#fPmKrV6sX_>n0cIqtXm_c z578y0;fbanX5nIFiaVM>L9XfKN(YA8{B$>6kE8iZ{(m*;w4 z95h(KedZsB0dq=79UPUsJmjQOhVNeN4CF7v`ydX^Mj!xp9pb0~`WlrWW!0!bPA8H$ z%PTRE>OtzNwou3s9Qw=iELAqRE1C%StR6vk0Zv7#`{)8VVw54}(*}T3B-sIm18_=^ z>ga69ZG^4Oy=-PbBz+9m$s4>*dfmv2BNJBy+88q23J`ygEwibT){`QmCdl?phFf$DoFP{ zl;wT#@+^Ov5gB%N#0r&`xCp&XE5%a-Z&Fp!F1wvAInWb9yt8oC0r%iYNhvcBlQwS5 z!-Hv>1^Y`_bOs$Hu#OHLcwwl78rnX;NLzyFb`L?N#WA~oo}WXCE;@Ubr&9uSY~d*4 zn1Nf<)Rn-%9XeE#3%6c`62xou(DfraNw|9yacH`d?Wnq=@iHw^Z!pXcD@6f;@D|h? zrJg2{y>QgqODpNIg9uq)A4DUd(rX3vb}86Xndv~&I?+VSoBED8P^k+eh?~U@ZDZtp zFMs0^lROpR#h+u5D(&awh(t&A+4eK4rZ7Gxs^5bfcJ!N7-86}KTX|(>;MsZF2g`vw zVK_{QtlO;fs2wWFGKnY@cJLDyYBY{&27N@>{gpQEEsh1w)S9TC+pI!q2}dN*bm%ft zyOJIkQiDB@mL`dW`}`8DSzV4}2I@s-X+hq(cz`bQuQqTjXxQ4un!3J`;W z){j^I1h>RZj(td&6@}ad}mR> zzR7+VOWB)d@x{*{tm$UdXf{bV%u3;?kZ-r>-n0acLhL zlUhhS3m15QjU3yEYu2cAS2Hkgs4p4f=1YqHf)XY73*3h7oo$GCan_B(ZUVPn=6FVSypBT-DW6Z3CLICciuO3uDAh^xZiO!3`PeVU>R4A{aQUvKms_3Rj zO0TImK*xc0m1GRLQE?O+YK$ZzvCa-05*h(54zmu` z-O>Ggc_w1@*`Rt9S_-@FGb+`fdW30*d*oMibxEf%gbYs%Qg$O(6sRgl+ePVj$-q~t zOsdFV>n;emJ8;X7RGEtKO5j-qmPbcN=s6=L==Vn~OlIy`=)ggJCbMO}A_41+?}uhtRgIm$t2!)wwRQ205^0T__o~XSJPYo2={y&AxcW#*-J}$K8E+1P9aWfX6q-GGsB1&S8F$khZZ?CpaZjXVg5t^h}3&A!C1VBNU{$xjyzJ@W-@H zGW;i%f>zGV(^O^_ufgE>lu(4o$7j~B(ZPIQIeWZ`61$X(arliC85aSyJx)%gYNn#u z!BKjoV*5oH;fAi7(ao76bL8vjr73GXd1NP1J{=WO3;z z0GkL85F?I*_dpH0LEnL@lJa*`nSa68|L^Cm+&p~$F_n-{(AHFOi{ilOy?d3|4Jhy!iznSh75K{$4c6T!%&tv7;HM&}Hk-g)N z{n0Z2n{F)>fIxAPq_m0jN0YC74J!UlNC+d%WF6wj(ar`H*->MJ6|(|KpTudj4K$?N ze(@H&K`!X32%$U9;ws9tI!2ZAJ;oe}gH8iZgY+gGF{Mnj729u_du;mf`>gx$ zT_y~DwQYpxltJhqP^Y?a|ey1k}=D9((7BwSn6AT$6G)hIKy{t zNFg2BEok{Dhr(yqS(!Gb;M-%;Oyv|yAJmgIl#V`+$8yh0T<(HUR-r194V=q#_ zwRiU>hV#<$qTdVnWvWt$_X5}4yy2!pYo<#y*~*_f18<@T!1jZQ(tN@ zSUW9hbcvX}L8XTkFY)pYxLF)9tb6upH8rW6b3((6UJ&Nl8m+#k%g=?F(8{LAS-yA7 zC;zRi!N&4iRYNf`PgSFk%aKs!N2bP&O@_EE;`}04qTx@?R15<4Z)uE2p zJU$o;5_?v)I(YIW~d#5TYvR%6bVNGlS$W1)jr#%z88`4EAY zs2b_kA>Ogh&#Rx`?S+i-CqV|@uN-xHC3@8aM+T<`kJGK##G!wW?JV>R;iC>76zm}3 zB`Tw2pSQm2VXSG&7Hb^V7dN-SHr?^yT;$}K{7M?+-$-8jBEZ}YzM~*~C)hGy(c=== zo^H&VjipV(A~-rE*Wm>W;YK5{*4XH{9nuVCv7n-2WLxnvbQ|y8*^2lB7klyu>)N{W z4X?mQ_9J5c1f9#0%SL1GURAhs+6@Hpu$}KFq;@wA-)h3F|sIiX7 zcP8j@g$Yg_kn*0VtqPC#xp^6_$=u!1R-b;fG08%{P?hsd;8z75lGnd?oD2M)f_nH(Xi?AZ&9WD){Q6i&iUJv(J%XE7stY&Kok^y6nMCqD?gkUb1t#p<`s(t8*wIiUq~P=rKzGtH z0~0FOMf7UD4Qm?gXshe!sAucD ze+-<)6zQV0wuCoE^#e3e2)20wy^zH=T~v4JD=wj5Nk|Clf`6@rTFbF@I;T!6+n#13 zdRjCFUU=@^)~`^>SE6%~9tR~$R~(mmJz^!7GSv%+&oKl}peV0%wLD1VU%sPOa@g16U<63{Pr9jkO-`Sgx0ubT-1u zOhB}p>&NvhExVWXSyV)oyMt;G4Xg$%5_6sG3|X$LSe;=0$xxkkyY-)STGt;tWYOur z+Z=Fte{gk@-JsdDZEv)1^tF3((IjPJN{Kw~G>sn>D{qW)@-EoQhc(tclKdQX+Gl`Z0f;IWj&-$mMW> zktI73F=?uhSqug2Hi`m{N0D;T_?rekTv0Tq|%@AqpdS2S0i;k3oS2{b@M+ zC|L8+CdZ?@fN_Xhi~Dul0Wdl(o4`;GkW zYr>hB%kgx}PF>r+3Im>K?bjb8qlC;_|uBIW|kPQ1DdtYljvrqAr zyNT0zCflU>pza9;yu5I(rEX`bkqCB~+0f~byY>AeOoZo*V4wF>)uzH&)taSgIx4*92OB8L+cdEs%`kKS(b@`D zitdTpIDE36>HRvm1_vZwU0q6WiqwVz6zEf*z!=7S^Hwsu81EI1BQ641_OkRyiQasE 
z!%bu`%R#HSy72?s2tt!{m74HkrFF&i2JZ@k)-&X8>>|BGO?d+(XBDp!5;KSMA!W#x zZTNebdNh}6XvRhqjwd$p7Z|GpA|n2+CXFqK3J3;PoZEw6eC^l`9@!;S+@DBNe9!cY z%Nvx*%wm@GosdES9yWeDoXl!>-vWoQ4n~si4cN<~_|v+6czze-?TBA_Kcc_q4Me9u z;$-L$fRqaVdXKrDQDcsmESQ!xpo=H`-Hejnkl({&b-NDoL_ND1pJK`8I!^UBaK7% z*sOd_AAH-jDVExgH|v&+Pe%v+P-Qc7Z|b&!UlTPAXWsu*Z2J{jZt%<)s{>@D8isJAt!J^qZ+lfrT=CAb&p;PQ&)ntT>3IIkz? zm5P!5rNDI$)H+pS<#ZWTQxxeU@+Er4zVcu)PtjydHcSG;`4NoDa;j%?OR-G|31@D$ zpceIPvwJ(szfQQ^Jl{bYuk@8hTM?0L_n9t>BgfzlRqiXS(`JO(h7uDPFeBE-1YNsW1=tZS)OpR!-CwN5Hgxcn|Eq=jyiL}uF67J z3jzN4+?V1Crm~Wqd9IZ#07LnGkD?FKLL*yp^;U@O{RAt7^t9FYJ$X!P>LurN$cFyb zj6z?ku3|SRSzUPWgI&c?zEof-Ix@<0s4ar343{rs&Xh%meNXYeSTby+?&!?2OjJ+Q z2r21hYA55T>ihXGGFSR~dp44Zr=?&zHYi%CZA^>kN!(ZH*zEC~N?2bK(>5RoQ zP6c=yWRsRnYLt#7EbJ`?-0TN@RM4GPBFc>HrCq|G{$)xQjflaqn6Z~!YwanfgWQ=$ zi^i_*3XIlE*xRV)nLBl9JqLxNuJj@bA4z*HvKlz*x|$-T=Pb-lO0os+`@S*S6|BJ) zLgpXw@#Awe(m$%XkH0oFT&DHDb+nY$|Ij+Bc3qO3eDY=uG}u(qj-H9h>TK3MTu~}$ z8&7h`ce~`d(qWR7k&^ahT!y)gHR^2B@28jU-oho}owX=;L!6&=A8eDLJ%>fi6M60nw~b%$E<7u- zr7{nrJ!s-ZJ5O&`a)@}_~-vos+!v(k;HxlN&l#T7WI z}UhG{hR})p*YG!b|Ld83p_fCwAcBKt=pv&ZS_AV`>p}YLs0Hn{x^yyo< zcjfgR4IOOeb>F1UDT_Hk1@;Vk4_a8wW4j7OxuLQOu$c>2|dFNqHocOGQ9!EmNt_b zuGMf)znO<8jG3r_EH3;Hgbrk;+aqc7NolQbg+v+5G$g<(gzCW_QoR@CJvg6NyiIyS z^~_eaESb&CDB6~-N8h(Bs8Rg+wY!R&+IslTPqucus8Lypbih$YQiPphP8FSzr>(l+ zomZN>k)J_bE0}v?FESSx^s#hrn{MmQtB{MZk|31lTrNTob?j!MLouzET{Gs$+xO%I zQPiuRE48$RkX%VmE|Fmu0pEkZw>>IWXmQ!l$i^JwH)qq^UnSvMsMk`SG?EewYIckK zbF<;4@X4=81)Cx?8&Rt}-T_mRqs$yz+$ryl7&cDhP&=lgmm6eeVEI#*R12$T>a>v- z{812X=w<(;#HU)Oayxy3lHxx|f77@2NurEXN-=ZvOd4lO%CBT&=~bX`vmOUr_O|{I zR`=iuZ__0CLH3DopHsqAYhMxvs|Q2O@cHe`m}XG3m1Xm+>(ja5N-$!Oeh-UWZ1ym6 zKyX$lX8QomX8+|Tx6zfG0rP1d$`_A;FgYTiYe+DNr%}*rHDP6Ti|VqlLlH=ewpKc zw0GrEO`KaCMQsh>(q}<^N@IkAz%W}fBq0e~f*?jPghho{CdmW>*$9b=pyCQv(O1zT z746f?@>HO%s89=Ttte{UwTjk*%JYeJ;Q;Qw31JDE_SAEp|1p0}zVF`O{l0te%*nlH z?srvRj@oexDruP<#A=FEd^N+dZd+~jq>>5#Pxck?J!Ty@)>NMYR&ocXG9qHot19Awto4~|d~K!Xbb;$8$;YpNX$dDs|iOTVuzXy&KXSjo97K?oj;=URO`=34Kve|IzEJozGkw zX4ZZ;uibI$obzSn_Xb>gHU8_C?dy}{d|F>tzR8XFZtH@Jp8I~-Kec+Z;mKiPOay{Z zo`|HElOOsY#LpW;la_eDjKUArN?xC1oSVBB%`To=rWoJcH~IMmXO;oBBw7Pj~kRHln88PlB}xuY-IYn5~Zx} zlP@p`B8wo=KRTjuPyRLa`M?o5Ul6`Qp3mo9-&^7(I)|keV@p4e`qk6rr>kzxRhO^) zJ$O#0+%MNTeSc6wUOPQE`^T&C`IFOw7mgTJ<`x{3oV?#N*8Ac%c3dN8ZHl;P`>#>A zH`hO{7JKi$x%XJ+k&cnN#F#OH!_S>FzpQMicKUI9Y{+GbybQ-GMDh1j5$So;=EX3M zePWpI{Ruq8&Fysh!ic=EA^Q&(7S}P<&>_H`cjMNu=myW#`^(&p*}G2q`nacXwRGsT zRSeAYc!wKr3xDa@u};tCFL564SoK5X&WQn!9`kwLZ+2Y_J-~A=JQ z=IEx3jSkMHZy9kX+U`#)D&OeUN=aOEXZg~72?0$Mckg@U6nX-8Ut33*moojI9@ls6 zNzNNKZZ;^>4nLzD&s~u+JTE;#a+jHa*X@YSYhFEI}*=X zzvo3BtXlzl?KVsvI5Dy4R#bLT$4hY(=s4E_D5e)UbTGZ9d@3q@WIT}JvZh>A=vb-w zxmhC@lvLL@o@!BT&k9NOg|QhkZZFPE^m+m@wmq-g-jueuBBv@PPyP6@9b>s@^4vnh ziRzcEj`YBnS?#IItI*l5N3+|STiVlRt-Igy8}7ktoR7pVi)oUU-n=~6Lp<)m+(=sb z7e1#-i#(9hE3F%r;{Rze4J*!`x%g;6E=4qt}4&2Y$wXV7(Lm}c;Qrwj|vGrF2nVBZ{m{^U?Q6>q!D_mgh*DXxwMBh zwKS?i#-+tEgpg1hLZm3dGId0B<}|T1Gfm2p(E@lBj)`qjX;mZ|)ud9YGuS3B4VP&o z1e?4!4}&ynCy73dOA9hjP~(MB)DVr1pt5{nKng)H)sN+i`7vaW+|P%KKnMy#7zo1v z%wi)j8^)+z7mY_D?>IV{oGs#qcF7}uacL=fy_OAvMx)W!i27=D$so*Pu|Nm`5dY$bw--`8GzEBmX57 zYg#KD(irS0oknJm61sO(&}Z^3rU+t=QKOS(fSsOcaTtzF%9d+%DqPRQwOXY@ikrO) zBw;vU@01k~IViQ5zzW4|9J5146X|_NyC;+iYoX>*@Ll>&iSL4&mniG?{JsJ@#Re($ zeMfu6nx8wi5?3d4X(m8M$Z>;GPqVp}z95}zu`SiS9=$@ZB>KwypeuP>p4p#;LN;F` zHJG2Wa6ZpqP{`PHNKS`kC<3r>2mXfwO2f2DT{6XL?p#4JLOS%2(>Q~o$Y2dwx`@O4QUGJrV_ag51uK%05 zD7}##p(evQBN?~Z3LXDS@iENo&UnD&?Vn^k5Ht#kf>1F8Plr$uL}$e8*m>4?8a(Aa z9im_8I6ZaJsv#%FdJbAxyii;|Ka2?4Ha}33nYwf1_~42G;)*fx06%wVAw755=s-{D z$Ae?!Lre3hl$AacoCmo%mkeAG*-gV;&v?{7|30n?$Aw)SXM@NBz1 
zA_I)l5i*68jPj^3(%U03R7pl|T@4b$`Lj47bU0K1^I1qRg!m!B7{m_>MT7WsmLQ1E z;FBjqarjI>79SO$3>FlGp>!rIBv=rNzycQL7lQf0!MsE;TrJm7F-sYER|*2gAW)>y zknI;OHQmvQOoEEhEqaa@5ppnY7EDE{aH6%kB@(*-hX7nKOAUd z4$)qgB(=+zEZgLmzy2OIAN;r>bQwg+S$y?=o!6mM6EK1{BC+8BHUAzXzTGAM`k)CX z4-RUWKUiL~=xY1uU14`y7YGV(EDBrY>el8ESiHj}ezyOW$(NE6;|}eLTrlBLt!Sjr z;kGqGh9Ss*_&U8wN64KhFpN1- M+}#C{WSRf}0L@dr@&Et; literal 0 HcmV?d00001 diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index bd40938c9..00107411a 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -4,6 +4,7 @@ import strands from strands import Agent from strands.models import BedrockModel +from strands.types.content import ContentBlock @pytest.fixture @@ -27,12 +28,20 @@ def non_streaming_model(): @pytest.fixture def streaming_agent(streaming_model, system_prompt): - return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + return Agent( + model=streaming_model, + system_prompt=system_prompt, + load_tools_from_directory=False, + ) @pytest.fixture def non_streaming_agent(non_streaming_model, system_prompt): - return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + return Agent( + model=non_streaming_model, + system_prompt=system_prompt, + load_tools_from_directory=False, + ) @pytest.fixture @@ -184,6 +193,42 @@ def test_invoke_multi_modal_input(streaming_agent, yellow_img): assert "yellow" in text +def test_document_citations(non_streaming_agent, letter_pdf): + content: list[ContentBlock] = [ + { + "document": { + "name": "letter to shareholders", + "source": {"bytes": letter_pdf}, + "citations": {"enabled": True}, + "context": "This is a letter to shareholders", + "format": "pdf", + }, + }, + {"text": "What does the document say about artificial intelligence? Use citations to back up your answer."}, + ] + non_streaming_agent(content) + + assert any("citationsContent" in content for content in non_streaming_agent.messages[-1]["content"]) + + +def test_document_citations_streaming(streaming_agent, letter_pdf): + content: list[ContentBlock] = [ + { + "document": { + "name": "letter to shareholders", + "source": {"bytes": letter_pdf}, + "citations": {"enabled": True}, + "context": "This is a letter to shareholders", + "format": "pdf", + }, + }, + {"text": "What does the document say about artificial intelligence? 
Use citations to back up your answer."}, + ] + streaming_agent(content) + + assert any("citationsContent" in content for content in streaming_agent.messages[-1]["content"]) + + def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow_color): content = [ {"text": "Is this image red, blue, or yellow?"}, diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index bf5668349..66c5fe9ad 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -2,8 +2,8 @@ import pytest -from src.strands.agent import AgentResult from strands import Agent, tool +from strands.agent import AgentResult from strands.models.bedrock import BedrockModel from strands.types.exceptions import MaxTokensReachedException From 94b41b4ae676f85d5b91e241389fa69ee17b54a5 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 28 Aug 2025 13:20:30 -0400 Subject: [PATCH 061/104] feat: Enable hooks for MultiAgents (#760) It's been a customer ask and we don't have a pressing need to keep it restricted. The primary concern is that because agent's state is manipulated between invocations (state is reset) hooks designed for a single agent may not work for multi-agents. With documentation, we can guide folks to be aware of what happens rather than restricting it outright. --------- Co-authored-by: Mackenzie Zastrow --- src/strands/multiagent/graph.py | 4 --- src/strands/multiagent/swarm.py | 4 --- tests/fixtures/mock_hook_provider.py | 45 ++++++++++++++++++++++++-- tests/strands/multiagent/test_graph.py | 18 ----------- tests/strands/multiagent/test_swarm.py | 16 +-------- tests_integ/test_multiagent_graph.py | 40 +++++++++++++++++++---- tests_integ/test_multiagent_swarm.py | 34 ++++++++++++++++--- 7 files changed, 106 insertions(+), 55 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 9aee260b1..081193b10 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -201,10 +201,6 @@ def _validate_node_executor( if executor._session_manager is not None: raise ValueError("Session persistence is not supported for Graph agents yet.") - # Check for callbacks - if executor.hooks.has_callbacks(): - raise ValueError("Agent callbacks are not supported for Graph agents yet.") - class GraphBuilder: """Builder pattern for constructing graphs.""" diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index a96c92de8..d730d5156 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -318,10 +318,6 @@ def _validate_swarm(self, nodes: list[Agent]) -> None: if node._session_manager is not None: raise ValueError("Session persistence is not supported for Swarm agents yet.") - # Check for callbacks - if node.hooks.has_callbacks(): - raise ValueError("Agent callbacks are not supported for Swarm agents yet.") - def _inject_swarm_tools(self) -> None: """Add swarm coordination tools to each agent.""" # Create tool functions with proper closures diff --git a/tests/fixtures/mock_hook_provider.py b/tests/fixtures/mock_hook_provider.py index 8d7e93253..6bf7b8c77 100644 --- a/tests/fixtures/mock_hook_provider.py +++ b/tests/fixtures/mock_hook_provider.py @@ -1,13 +1,44 @@ -from typing import Iterator, Tuple, Type +from typing import Iterator, Literal, Tuple, Type -from strands.hooks import HookEvent, HookProvider, HookRegistry +from strands import Agent +from strands.experimental.hooks 
import ( + AfterModelInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + BeforeToolInvocationEvent, +) +from strands.hooks import ( + AfterInvocationEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + HookEvent, + HookProvider, + HookRegistry, + MessageAddedEvent, +) class MockHookProvider(HookProvider): - def __init__(self, event_types: list[Type]): + def __init__(self, event_types: list[Type] | Literal["all"]): + if event_types == "all": + event_types = [ + AgentInitializedEvent, + BeforeInvocationEvent, + AfterInvocationEvent, + AfterToolInvocationEvent, + BeforeToolInvocationEvent, + BeforeModelInvocationEvent, + AfterModelInvocationEvent, + MessageAddedEvent, + ] + self.events_received = [] self.events_types = event_types + @property + def event_types_received(self): + return [type(event) for event in self.events_received] + def get_events(self) -> Tuple[int, Iterator[HookEvent]]: return len(self.events_received), iter(self.events_received) @@ -17,3 +48,11 @@ def register_hooks(self, registry: HookRegistry) -> None: def add_event(self, event: HookEvent) -> None: self.events_received.append(event) + + def extract_for(self, agent: Agent) -> "MockHookProvider": + """Extracts a hook provider for the given agent, including the events that were fired for that agent. + + Convenience method when sharing a hook provider between multiple agents.""" + child_provider = MockHookProvider(self.events_types) + child_provider.events_received = [event for event in self.events_received if event.agent == agent] + return child_provider diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index c60361da8..9977c54cd 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -873,15 +873,6 @@ class TestHookProvider(HookProvider): def register_hooks(self, registry, **kwargs): registry.add_callback(AgentInitializedEvent, lambda e: None) - agent_with_hooks = create_mock_agent("agent_with_hooks") - agent_with_hooks._session_manager = None - agent_with_hooks.hooks = HookRegistry() - agent_with_hooks.hooks.add_hook(TestHookProvider()) - - builder = GraphBuilder() - with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"): - builder.add_node(agent_with_hooks) - # Test validation in Graph constructor (when nodes are passed directly) # Test with session manager in Graph constructor node_with_session = GraphNode("node_with_session", agent_with_session) @@ -892,15 +883,6 @@ def register_hooks(self, registry, **kwargs): entry_points=set(), ) - # Test with callbacks in Graph constructor - node_with_hooks = GraphNode("node_with_hooks", agent_with_hooks) - with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"): - Graph( - nodes={"node_with_hooks": node_with_hooks}, - edges=set(), - entry_points=set(), - ) - @pytest.mark.asyncio async def test_controlled_cyclic_execution(): diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 91b677fa4..74f89241f 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -5,8 +5,7 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState -from strands.hooks import AgentInitializedEvent -from strands.hooks.registry import HookProvider, HookRegistry +from strands.hooks.registry import HookRegistry from strands.multiagent.base import Status from strands.multiagent.swarm import 
SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState from strands.session.session_manager import SessionManager @@ -470,16 +469,3 @@ def test_swarm_validate_unsupported_features(): with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"): Swarm([agent_with_session]) - - # Test with callbacks (should fail) - class TestHookProvider(HookProvider): - def register_hooks(self, registry, **kwargs): - registry.add_callback(AgentInitializedEvent, lambda e: None) - - agent_with_hooks = create_mock_agent("agent_with_hooks") - agent_with_hooks._session_manager = None - agent_with_hooks.hooks = HookRegistry() - agent_with_hooks.hooks.add_hook(TestHookProvider()) - - with pytest.raises(ValueError, match="Agent callbacks are not supported for Swarm agents yet"): - Swarm([agent_with_hooks]) diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index e1f3a2f3f..bc9b0ea8b 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,8 +1,11 @@ import pytest from strands import Agent, tool +from strands.experimental.hooks import AfterModelInvocationEvent, BeforeModelInvocationEvent +from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, MessageAddedEvent from strands.multiagent.graph import GraphBuilder from strands.types.content import ContentBlock +from tests.fixtures.mock_hook_provider import MockHookProvider @tool @@ -18,49 +21,59 @@ def multiply_numbers(x: int, y: int) -> int: @pytest.fixture -def math_agent(): +def hook_provider(): + return MockHookProvider("all") + + +@pytest.fixture +def math_agent(hook_provider): """Create an agent specialized in mathematical operations.""" return Agent( model="us.amazon.nova-pro-v1:0", system_prompt="You are a mathematical assistant. Always provide clear, step-by-step calculations.", + hooks=[hook_provider], tools=[calculate_sum, multiply_numbers], ) @pytest.fixture -def analysis_agent(): +def analysis_agent(hook_provider): """Create an agent specialized in data analysis.""" return Agent( model="us.amazon.nova-pro-v1:0", + hooks=[hook_provider], system_prompt="You are a data analysis expert. Provide insights and interpretations of numerical results.", ) @pytest.fixture -def summary_agent(): +def summary_agent(hook_provider): """Create an agent specialized in summarization.""" return Agent( model="us.amazon.nova-lite-v1:0", + hooks=[hook_provider], system_prompt="You are a summarization expert. Create concise, clear summaries of complex information.", ) @pytest.fixture -def validation_agent(): +def validation_agent(hook_provider): """Create an agent specialized in validation.""" return Agent( model="us.amazon.nova-pro-v1:0", + hooks=[hook_provider], system_prompt="You are a validation expert. Check results for accuracy and completeness.", ) @pytest.fixture -def image_analysis_agent(): +def image_analysis_agent(hook_provider): """Create an agent specialized in image analysis.""" return Agent( + hooks=[hook_provider], system_prompt=( "You are an image analysis expert. Describe what you see in images and provide detailed analysis." 
- ) + ), ) @@ -149,7 +162,7 @@ def proceed_to_second_summary(state): @pytest.mark.asyncio -async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img): +async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img, hook_provider): """Test graph execution with multi-modal image input.""" builder = GraphBuilder() @@ -186,3 +199,16 @@ async def test_graph_execution_with_image(image_analysis_agent, summary_agent, y # Verify both nodes completed assert "image_analyzer" in result.results assert "summarizer" in result.results + + expected_hook_events = [ + AgentInitializedEvent, + BeforeInvocationEvent, + MessageAddedEvent, + BeforeModelInvocationEvent, + AfterModelInvocationEvent, + MessageAddedEvent, + AfterInvocationEvent, + ] + + assert hook_provider.extract_for(image_analysis_agent).event_types_received == expected_hook_events + assert hook_provider.extract_for(summary_agent).event_types_received == expected_hook_events diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index 6fe5700aa..76860f687 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -1,8 +1,16 @@ import pytest from strands import Agent, tool +from strands.experimental.hooks import ( + AfterModelInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + BeforeToolInvocationEvent, +) +from strands.hooks import AfterInvocationEvent, BeforeInvocationEvent, MessageAddedEvent from strands.multiagent.swarm import Swarm from strands.types.content import ContentBlock +from tests.fixtures.mock_hook_provider import MockHookProvider @tool @@ -22,7 +30,12 @@ def calculate(expression: str) -> str: @pytest.fixture -def researcher_agent(): +def hook_provider(): + return MockHookProvider("all") + + +@pytest.fixture +def researcher_agent(hook_provider): """Create an agent specialized in research.""" return Agent( name="researcher", @@ -30,12 +43,13 @@ def researcher_agent(): "You are a research specialist who excels at finding information. When you need to perform calculations or" " format documents, hand off to the appropriate specialist." ), + hooks=[hook_provider], tools=[web_search], ) @pytest.fixture -def analyst_agent(): +def analyst_agent(hook_provider): """Create an agent specialized in data analysis.""" return Agent( name="analyst", @@ -43,15 +57,17 @@ def analyst_agent(): "You are a data analyst who excels at calculations and numerical analysis. When you need" " research or document formatting, hand off to the appropriate specialist." ), + hooks=[hook_provider], tools=[calculate], ) @pytest.fixture -def writer_agent(): +def writer_agent(hook_provider): """Create an agent specialized in writing and formatting.""" return Agent( name="writer", + hooks=[hook_provider], system_prompt=( "You are a professional writer who excels at formatting and presenting information. When you need research" " or calculations, hand off to the appropriate specialist." 
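For context on what the hook wiring in these fixtures enables: with the Swarm/Graph validation removed, one hook provider can now be attached to every agent in a multi-agent construct and its callbacks fire during execution. A minimal sketch under the assumption that events fire per agent as exercised by the tests above; the `CountingHookProvider`, agent names, and task string are illustrative and not part of this change:

```python
from strands import Agent
from strands.hooks import BeforeInvocationEvent, HookProvider, HookRegistry
from strands.multiagent.swarm import Swarm


class CountingHookProvider(HookProvider):
    """Illustrative provider that counts how often any attached agent starts an invocation."""

    def __init__(self) -> None:
        self.invocations = 0

    def register_hooks(self, registry: HookRegistry) -> None:
        registry.add_callback(BeforeInvocationEvent, self._on_before_invocation)

    def _on_before_invocation(self, event: BeforeInvocationEvent) -> None:
        self.invocations += 1


counter = CountingHookProvider()
researcher = Agent(name="researcher", hooks=[counter])
writer = Agent(name="writer", hooks=[counter])

# Previously this construction raised "Agent callbacks are not supported for Swarm agents yet."
swarm = Swarm([researcher, writer])
swarm("Research the topic, then hand off to the writer for a short summary")
assert counter.invocations > 0  # hooks fired for the agents the swarm actually invoked
```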
@@ -59,7 +75,7 @@ def writer_agent(): ) -def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent): +def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent, hook_provider): """Test swarm execution with string input.""" # Create the swarm swarm = Swarm([researcher_agent, analyst_agent, writer_agent]) @@ -82,6 +98,16 @@ def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_age # Verify agent history - at least one agent should have been used assert len(result.node_history) > 0 + # Just ensure that hooks are emitted; actual content is not verified + researcher_hooks = hook_provider.extract_for(researcher_agent).event_types_received + assert BeforeInvocationEvent in researcher_hooks + assert MessageAddedEvent in researcher_hooks + assert BeforeModelInvocationEvent in researcher_hooks + assert BeforeToolInvocationEvent in researcher_hooks + assert AfterToolInvocationEvent in researcher_hooks + assert AfterModelInvocationEvent in researcher_hooks + assert AfterInvocationEvent in researcher_hooks + @pytest.mark.asyncio async def test_swarm_execution_with_image(researcher_agent, analyst_agent, writer_agent, yellow_img): From b008cf506b7081171c5d4efe1e18e1c356488a9b Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 28 Aug 2025 15:54:13 -0400 Subject: [PATCH 062/104] Add invocation_state to ToolContext (#761) Addresses issue #579, #750 --------- Co-authored-by: Mackenzie Zastrow --- src/strands/tools/decorator.py | 10 +++++++--- src/strands/types/tools.py | 6 +++++- tests/strands/tools/test_decorator.py | 15 +++++++++++++-- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 75abac9ed..2ce6d946f 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -265,10 +265,13 @@ def inject_special_parameters( Args: validated_input: The validated input parameters (modified in place). tool_use: The tool use request containing tool invocation details. - invocation_state: Context for the tool invocation, including agent state. + invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(), + agent.invoke_async(), etc.). """ if self._context_param and self._context_param in self.signature.parameters: - tool_context = ToolContext(tool_use=tool_use, agent=invocation_state["agent"]) + tool_context = ToolContext( + tool_use=tool_use, agent=invocation_state["agent"], invocation_state=invocation_state + ) validated_input[self._context_param] = tool_context # Inject agent if requested (backward compatibility) @@ -433,7 +436,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw Args: tool_use: The tool use specification from the Agent. - invocation_state: Context for the tool invocation, including agent state. + invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(), + agent.invoke_async(), etc.). **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index bb7c874f6..1e0f4b841 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -132,6 +132,8 @@ class ToolContext: tool_use: The complete ToolUse object containing tool invocation details. 
agent: The Agent instance executing this tool, providing access to conversation history, model configuration, and other agent state. + invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(), + agent.invoke_async(), etc.). Note: This class is intended to be instantiated by the SDK. Direct construction by users @@ -140,6 +142,7 @@ class ToolContext: tool_use: ToolUse agent: "Agent" + invocation_state: dict[str, Any] ToolChoice = Union[ @@ -246,7 +249,8 @@ def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Args: tool_use: The tool use request containing tool ID and parameters. - invocation_state: Context for the tool invocation, including agent state. + invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(), + agent.invoke_async(), etc.). **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index e490c7bb0..02e7eb445 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -2,6 +2,7 @@ Tests for the function-based tool decorator pattern. """ +from asyncio import Queue from typing import Any, Dict, Optional, Union from unittest.mock import MagicMock @@ -1039,7 +1040,7 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None] assert "NoneType: None" in result["content"][0]["text"] -async def _run_context_injection_test(context_tool: AgentTool): +async def _run_context_injection_test(context_tool: AgentTool, additional_context=None): """Common test logic for context injection tests.""" tool: AgentTool = context_tool generator = tool.stream( @@ -1052,6 +1053,7 @@ async def _run_context_injection_test(context_tool: AgentTool): }, invocation_state={ "agent": Agent(name="test_agent"), + **(additional_context or {}), }, ) tool_results = [value async for value in generator] @@ -1074,6 +1076,8 @@ async def _run_context_injection_test(context_tool: AgentTool): async def test_tool_context_injection_default(): """Test that ToolContext is properly injected with default parameter name (tool_context).""" + value_to_pass = Queue() # a complex value that is not serializable + @strands.tool(context=True) def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict: """Tool that uses ToolContext to access tool_use_id.""" @@ -1081,6 +1085,8 @@ def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict: tool_name = tool_context.tool_use["name"] agent_from_tool_context = tool_context.agent + assert tool_context.invocation_state["test_reference"] is value_to_pass + return { "status": "success", "content": [ @@ -1090,7 +1096,12 @@ def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict: ], } - await _run_context_injection_test(context_tool) + await _run_context_injection_test( + context_tool, + { + "test_reference": value_to_pass, + }, + ) @pytest.mark.asyncio From ae9d5ad0b0faf904a62b4d3e5fe84069f3ec9f38 Mon Sep 17 00:00:00 2001 From: Dom Bavaro Date: Fri, 29 Aug 2025 10:08:55 -0400 Subject: [PATCH 063/104] feat(models): Add VPC endpoint support to BedrockModel class (#502) Co-authored-by: Dean Schmigelski --- src/strands/models/bedrock.py | 3 +++ tests/strands/models/test_bedrock.py | 33 +++++++++++++++++++++------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 
0fe332a47..c44717041 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -103,6 +103,7 @@ def __init__( boto_session: Optional[boto3.Session] = None, boto_client_config: Optional[BotocoreConfig] = None, region_name: Optional[str] = None, + endpoint_url: Optional[str] = None, **model_config: Unpack[BedrockConfig], ): """Initialize provider instance. @@ -112,6 +113,7 @@ def __init__( boto_client_config: Configuration to use when creating the Bedrock-Runtime Boto Client. region_name: AWS region to use for the Bedrock service. Defaults to the AWS_REGION environment variable if set, or "us-west-2" if not set. + endpoint_url: Custom endpoint URL for VPC endpoints (PrivateLink) **model_config: Configuration options for the Bedrock model. """ if region_name and boto_session: @@ -143,6 +145,7 @@ def __init__( self.client = session.client( service_name="bedrock-runtime", config=client_config, + endpoint_url=endpoint_url, region_name=resolved_region, ) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 09e508845..f1a2250e4 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -129,7 +129,7 @@ def test__init__with_default_region(session_cls, mock_client_method): with unittest.mock.patch.object(os, "environ", {}): BedrockModel() session_cls.return_value.client.assert_called_with( - region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY + region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=None ) @@ -139,14 +139,14 @@ def test__init__with_session_region(session_cls, mock_client_method): BedrockModel() - mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY) + mock_client_method.assert_called_with(region_name="eu-blah-1", config=ANY, service_name=ANY, endpoint_url=None) def test__init__with_custom_region(mock_client_method): """Test that BedrockModel uses the provided region.""" custom_region = "us-east-1" BedrockModel(region_name=custom_region) - mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY) + mock_client_method.assert_called_with(region_name=custom_region, config=ANY, service_name=ANY, endpoint_url=None) def test__init__with_default_environment_variable_region(mock_client_method): @@ -154,7 +154,7 @@ def test__init__with_default_environment_variable_region(mock_client_method): with unittest.mock.patch.object(os, "environ", {"AWS_REGION": "eu-west-2"}): BedrockModel() - mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY) + mock_client_method.assert_called_with(region_name="eu-west-2", config=ANY, service_name=ANY, endpoint_url=None) def test__init__region_precedence(mock_client_method, session_cls): @@ -164,21 +164,38 @@ def test__init__region_precedence(mock_client_method, session_cls): # specifying a region always wins out BedrockModel(region_name="us-specified-1") - mock_client_method.assert_called_with(region_name="us-specified-1", config=ANY, service_name=ANY) + mock_client_method.assert_called_with( + region_name="us-specified-1", config=ANY, service_name=ANY, endpoint_url=None + ) # other-wise uses the session's BedrockModel() - mock_client_method.assert_called_with(region_name="us-session-1", config=ANY, service_name=ANY) + mock_client_method.assert_called_with( + region_name="us-session-1", config=ANY, service_name=ANY, endpoint_url=None + ) # environment variable next session_cls.return_value.region_name = None 
BedrockModel() - mock_client_method.assert_called_with(region_name="us-environment-1", config=ANY, service_name=ANY) + mock_client_method.assert_called_with( + region_name="us-environment-1", config=ANY, service_name=ANY, endpoint_url=None + ) mock_os_environ.pop("AWS_REGION") session_cls.return_value.region_name = None # No session region BedrockModel() - mock_client_method.assert_called_with(region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY) + mock_client_method.assert_called_with( + region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=None + ) + + +def test__init__with_endpoint_url(mock_client_method): + """Test that BedrockModel uses the provided endpoint_url for VPC endpoints.""" + custom_endpoint = "https://vpce-12345-abcde.bedrock-runtime.us-west-2.vpce.amazonaws.com" + BedrockModel(endpoint_url=custom_endpoint) + mock_client_method.assert_called_with( + region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint + ) def test__init__with_region_and_session_raises_value_error(): From 7a5caad1e8d9d77315e09894241261bb75f1892f Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Sat, 30 Aug 2025 01:36:49 +0800 Subject: [PATCH 064/104] fix: fix stop reason for bedrock model when stop_reason (#767) * fix: fix stop reason for bedrock model when stop_reason is end_turn in tool use response. * change logger info to warning, optimize if condition * fix: add unit tests --------- Co-authored-by: Jack Yuan --- src/strands/models/bedrock.py | 31 ++++++++++++++++-- tests/strands/models/test_bedrock.py | 47 ++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c44717041..ba4828c1a 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -435,6 +435,8 @@ def _stream( logger.debug("got response from model") if streaming: response = self.client.converse_stream(**request) + # Track tool use events to fix stopReason for streaming responses + has_tool_use = False for chunk in response["stream"]: if ( "metadata" in chunk @@ -446,7 +448,24 @@ def _stream( for event in self._generate_redaction_events(): callback(event) - callback(chunk) + # Track if we see tool use events + if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"): + has_tool_use = True + + # Fix stopReason for streaming responses that contain tool use + if ( + has_tool_use + and "messageStop" in chunk + and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn" + ): + # Create corrected chunk with tool_use stopReason + modified_chunk = chunk.copy() + modified_chunk["messageStop"] = message_stop.copy() + modified_chunk["messageStop"]["stopReason"] = "tool_use" + logger.warning("Override stop reason from end_turn to tool_use") + callback(modified_chunk) + else: + callback(chunk) else: response = self.client.converse(**request) @@ -582,9 +601,17 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera yield {"contentBlockStop": {}} # Yield messageStop event + # Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side + current_stop_reason = response["stopReason"] + if current_stop_reason == "end_turn": + message_content = response["output"]["message"]["content"] + if any("toolUse" in content for content in message_content): + current_stop_reason = "tool_use" + 
logger.warning("Override stop reason from end_turn to tool_use") + yield { "messageStop": { - "stopReason": response["stopReason"], + "stopReason": current_stop_reason, "additionalModelResponseFields": response.get("additionalModelResponseFields"), } } diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index f1a2250e4..2f44c2e65 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1227,6 +1227,53 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist): assert "finished streaming response from model" in log_text +@pytest.mark.asyncio +async def test_stream_stop_reason_override_streaming(bedrock_client, model, messages, alist): + """Test that stopReason is overridden from end_turn to tool_use in streaming mode when tool use is detected.""" + bedrock_client.converse_stream.return_value = { + "stream": [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test_tool"}}}}, + {"contentBlockDelta": {"delta": {"test": {"input": '{"param": "value"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + } + + response = model.stream(messages) + events = await alist(response) + + # Find the messageStop event + message_stop_event = next(event for event in events if "messageStop" in event) + + # Verify stopReason was overridden to tool_use + assert message_stop_event["messageStop"]["stopReason"] == "tool_use" + + +@pytest.mark.asyncio +async def test_stream_stop_reason_override_non_streaming(bedrock_client, alist, messages): + """Test that stopReason is overridden from end_turn to tool_use in non-streaming mode when tool use is detected.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"param": "value"}}}], + } + }, + "stopReason": "end_turn", + } + + model = BedrockModel(model_id="test-model", streaming=False) + response = model.stream(messages) + events = await alist(response) + + # Find the messageStop event + message_stop_event = next(event for event in events if "messageStop" in event) + + # Verify stopReason was overridden to tool_use + assert message_stop_event["messageStop"]["stopReason"] == "tool_use" + + def test_format_request_cleans_tool_result_content_blocks(model, model_id): """Test that format_request cleans toolResult blocks by removing extra fields.""" messages = [ From cb4b7fb83ab34f1e41368f4988274e771646e3f8 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 29 Aug 2025 15:24:54 -0400 Subject: [PATCH 065/104] fix: Fix tool result message event (#771) Expand the Unit Tests for the yielded event to verify actual tool calls - previous to this, the events were not being emitted because the test was bailing out due to mocked guard rails. 
To better test the situation, we now have a much more extensive test for the successful tool call Co-authored-by: Mackenzie Zastrow --- src/strands/event_loop/event_loop.py | 2 +- .../strands/agent/hooks/test_agent_events.py | 329 ++++++++++++++++-- 2 files changed, 304 insertions(+), 27 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index a99ecc8a6..5d5085101 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -361,7 +361,7 @@ async def _handle_tool_execution( agent.messages.append(tool_result_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) - yield ToolResultMessageEvent(message=message) + yield ToolResultMessageEvent(message=tool_result_message) if cycle_span: tracer = get_tracer() diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index d63dd97d4..04b832259 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -1,3 +1,4 @@ +import asyncio import unittest.mock from unittest.mock import ANY, MagicMock, call @@ -11,49 +12,333 @@ from tests.fixtures.mocked_model_provider import MockedModelProvider +@strands.tool +def normal_tool(agent: Agent): + return f"Done with synchronous {agent.name}!" + + +@strands.tool +async def async_tool(agent: Agent): + await asyncio.sleep(0.1) + return f"Done with asynchronous {agent.name}!" + + +@strands.tool +async def streaming_tool(): + await asyncio.sleep(0.2) + yield {"tool_streaming": True} + yield "Final result" + + @pytest.fixture def mock_time(): with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock: yield mock -@pytest.mark.asyncio -async def test_stream_async_e2e(alist, mock_time): - @strands.tool - def fake_tool(agent: Agent): - return "Done!" 
+any_props = { + "agent": ANY, + "event_loop_cycle_id": ANY, + "event_loop_cycle_span": ANY, + "event_loop_cycle_trace": ANY, + "request_state": {}, +} + +@pytest.mark.asyncio +async def test_stream_e2e_success(alist): mock_provider = MockedModelProvider( [ - {"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}, - {"role": "assistant", "content": [{"text": "Okay invoking tool!"}]}, { "role": "assistant", - "content": [{"toolUse": {"name": "fake_tool", "toolUseId": "123", "input": {}}}], + "content": [ + {"text": "Okay invoking normal tool"}, + {"toolUse": {"name": "normal_tool", "toolUseId": "123", "input": {}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Invoking async tool"}, + {"toolUse": {"name": "async_tool", "toolUseId": "1234", "input": {}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Invoking streaming tool"}, + {"toolUse": {"name": "streaming_tool", "toolUseId": "12345", "input": {}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "I invoked the tools!"}, + ], }, - {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, ] ) + + mock_callback = unittest.mock.Mock() + agent = Agent(model=mock_provider, tools=[async_tool, normal_tool, streaming_tool], callback_handler=mock_callback) + + stream = agent.stream_async("Do the stuff", arg1=1013) + + tool_config = { + "toolChoice": {"auto": {}}, + "tools": [ + { + "toolSpec": { + "description": "async_tool", + "inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}}, + "name": "async_tool", + } + }, + { + "toolSpec": { + "description": "normal_tool", + "inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}}, + "name": "normal_tool", + } + }, + { + "toolSpec": { + "description": "streaming_tool", + "inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}}, + "name": "streaming_tool", + } + }, + ], + } + + tru_events = await alist(stream) + exp_events = [ + # Cycle 1: Initialize and invoke normal_tool + {"arg1": 1013, "init_event_loop": True}, + {"start": True}, + {"start_event_loop": True}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "Okay invoking normal tool"}}}}, + { + **any_props, + "arg1": 1013, + "data": "Okay invoking normal tool", + "delta": {"text": "Okay invoking normal tool"}, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"contentBlockStart": {"start": {"toolUse": {"name": "normal_tool", "toolUseId": "123"}}}}}, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}}, + { + **any_props, + "arg1": 1013, + "current_tool_use": {"input": {}, "name": "normal_tool", "toolUseId": "123"}, + "delta": {"toolUse": {"input": "{}"}}, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "tool_use"}}}, + { + "message": { + "content": [ + {"text": "Okay invoking normal tool"}, + {"toolUse": {"input": {}, "name": "normal_tool", "toolUseId": "123"}}, + ], + "role": "assistant", + } + }, + { + "message": { + "content": [ + { + "toolResult": { + "content": [{"text": "Done with synchronous Strands Agents!"}], + "status": "success", + "toolUseId": "123", + } + }, + ], + "role": "user", + } + }, + # Cycle 2: Invoke async_tool + {"start": True}, + {"start": True}, + {"start_event_loop": True}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": 
{"contentBlockDelta": {"delta": {"text": "Invoking async tool"}}}}, + { + **any_props, + "arg1": 1013, + "data": "Invoking async tool", + "delta": {"text": "Invoking async tool"}, + "event_loop_parent_cycle_id": ANY, + "messages": ANY, + "model": ANY, + "system_prompt": None, + "tool_config": tool_config, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"contentBlockStart": {"start": {"toolUse": {"name": "async_tool", "toolUseId": "1234"}}}}}, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}}, + { + **any_props, + "arg1": 1013, + "current_tool_use": {"input": {}, "name": "async_tool", "toolUseId": "1234"}, + "delta": {"toolUse": {"input": "{}"}}, + "event_loop_parent_cycle_id": ANY, + "messages": ANY, + "model": ANY, + "system_prompt": None, + "tool_config": tool_config, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "tool_use"}}}, + { + "message": { + "content": [ + {"text": "Invoking async tool"}, + {"toolUse": {"input": {}, "name": "async_tool", "toolUseId": "1234"}}, + ], + "role": "assistant", + } + }, + { + "message": { + "content": [ + { + "toolResult": { + "content": [{"text": "Done with asynchronous Strands Agents!"}], + "status": "success", + "toolUseId": "1234", + } + }, + ], + "role": "user", + } + }, + # Cycle 3: Invoke streaming_tool + {"start": True}, + {"start": True}, + {"start_event_loop": True}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "Invoking streaming tool"}}}}, + { + **any_props, + "arg1": 1013, + "data": "Invoking streaming tool", + "delta": {"text": "Invoking streaming tool"}, + "event_loop_parent_cycle_id": ANY, + "messages": ANY, + "model": ANY, + "system_prompt": None, + "tool_config": tool_config, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"contentBlockStart": {"start": {"toolUse": {"name": "streaming_tool", "toolUseId": "12345"}}}}}, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}}, + { + **any_props, + "arg1": 1013, + "current_tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, + "delta": {"toolUse": {"input": "{}"}}, + "event_loop_parent_cycle_id": ANY, + "messages": ANY, + "model": ANY, + "system_prompt": None, + "tool_config": tool_config, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "tool_use"}}}, + { + "message": { + "content": [ + {"text": "Invoking streaming tool"}, + {"toolUse": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}}, + ], + "role": "assistant", + } + }, + { + "message": { + "content": [ + { + "toolResult": { + # TODO update this text when we get tool streaming implemented; right now this + # TODO is of the form '' + "content": [{"text": ANY}], + "status": "success", + "toolUseId": "12345", + } + }, + ], + "role": "user", + } + }, + # Cycle 4: Final response + {"start": True}, + {"start": True}, + {"start_event_loop": True}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "I invoked the tools!"}}}}, + { + **any_props, + "arg1": 1013, + "data": "I invoked the tools!", + "delta": {"text": "I invoked the tools!"}, + "event_loop_parent_cycle_id": ANY, + "messages": ANY, + "model": ANY, + "system_prompt": None, + "tool_config": tool_config, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": 
"end_turn"}}}, + {"message": {"content": [{"text": "I invoked the tools!"}], "role": "assistant"}}, + { + "result": AgentResult( + stop_reason="end_turn", + message={"content": [{"text": "I invoked the tools!"}], "role": "assistant"}, + metrics=ANY, + state={}, + ) + }, + ] + assert tru_events == exp_events + + exp_calls = [call(**event) for event in exp_events] + act_calls = mock_callback.call_args_list + assert act_calls == exp_calls + + # Ensure that all events coming out of the agent are *not* typed events + typed_events = [event for event in tru_events if isinstance(event, TypedEvent)] + assert typed_events == [] + + +@pytest.mark.asyncio +async def test_stream_e2e_throttle_and_redact(alist, mock_time): model = MagicMock() model.stream.side_effect = [ ModelThrottledException("ThrottlingException | ConverseStream"), ModelThrottledException("ThrottlingException | ConverseStream"), - mock_provider.stream([]), + MockedModelProvider( + [ + {"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}, + ] + ).stream([]), ] mock_callback = unittest.mock.Mock() - agent = Agent(model=model, tools=[fake_tool], callback_handler=mock_callback) + agent = Agent(model=model, tools=[normal_tool], callback_handler=mock_callback) stream = agent.stream_async("Do the stuff", arg1=1013) # Base object with common properties throttle_props = { - "agent": ANY, - "event_loop_cycle_id": ANY, - "event_loop_cycle_span": ANY, - "event_loop_cycle_trace": ANY, + **any_props, "arg1": 1013, - "request_state": {}, } tru_events = await alist(stream) @@ -68,14 +353,10 @@ def fake_tool(agent: Agent): {"event": {"contentBlockStart": {"start": {}}}}, {"event": {"contentBlockDelta": {"delta": {"text": "INPUT BLOCKED!"}}}}, { - "agent": ANY, + **any_props, "arg1": 1013, "data": "INPUT BLOCKED!", "delta": {"text": "INPUT BLOCKED!"}, - "event_loop_cycle_id": ANY, - "event_loop_cycle_span": ANY, - "event_loop_cycle_trace": ANY, - "request_state": {}, }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}}, @@ -128,12 +409,8 @@ async def test_event_loop_cycle_text_response_throttling_early_end( # Base object with common properties common_props = { - "agent": ANY, - "event_loop_cycle_id": ANY, - "event_loop_cycle_span": ANY, - "event_loop_cycle_trace": ANY, + **any_props, "arg1": 1013, - "request_state": {}, } exp_events = [ From e7d95d6ad2c13dbde6257afbba96802822f70b26 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Sat, 30 Aug 2025 05:15:31 +0800 Subject: [PATCH 066/104] fix: fix loading tools with same tool name (#772) * fix: fix loading tools with same tool name * simplify if condition --------- Co-authored-by: Jack Yuan --- src/strands/tools/registry.py | 7 ++++++ tests/strands/tools/test_registry.py | 36 ++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index fd395ae77..6bb76f560 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -190,6 +190,13 @@ def register_tool(self, tool: AgentTool) -> None: tool.is_dynamic, ) + # Check duplicate tool name, throw on duplicate tool names except if hot_reloading is enabled + if tool.tool_name in self.registry and not tool.supports_hot_reload: + raise ValueError( + f"Tool name '{tool.tool_name}' already exists. Cannot register tools with exact same name." 
+            )
+
+        # Check for normalized name conflicts (- vs _)
         if self.registry.get(tool.tool_name) is None:
             normalized_name = tool.tool_name.replace("-", "_")

diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py
index 66494c987..ca3cded4c 100644
--- a/tests/strands/tools/test_registry.py
+++ b/tests/strands/tools/test_registry.py
@@ -120,3 +120,39 @@ def function() -> str:
         "tool_f",
     ]
     assert tru_tool_names == exp_tool_names
+
+
+def test_register_tool_duplicate_name_without_hot_reload():
+    """Test that registering a tool with duplicate name raises ValueError when hot reload is not supported."""
+    tool_1 = PythonAgentTool(tool_name="duplicate_tool", tool_spec=MagicMock(), tool_func=lambda: None)
+    tool_2 = PythonAgentTool(tool_name="duplicate_tool", tool_spec=MagicMock(), tool_func=lambda: None)
+
+    tool_registry = ToolRegistry()
+    tool_registry.register_tool(tool_1)
+
+    with pytest.raises(
+        ValueError, match="Tool name 'duplicate_tool' already exists. Cannot register tools with exact same name."
+    ):
+        tool_registry.register_tool(tool_2)
+
+
+def test_register_tool_duplicate_name_with_hot_reload():
+    """Test that registering a tool with duplicate name succeeds when hot reload is supported."""
+    # Create mock tools with hot reload support
+    tool_1 = MagicMock(spec=PythonAgentTool)
+    tool_1.tool_name = "hot_reload_tool"
+    tool_1.supports_hot_reload = True
+    tool_1.is_dynamic = False
+
+    tool_2 = MagicMock(spec=PythonAgentTool)
+    tool_2.tool_name = "hot_reload_tool"
+    tool_2.supports_hot_reload = True
+    tool_2.is_dynamic = False
+
+    tool_registry = ToolRegistry()
+    tool_registry.register_tool(tool_1)
+
+    tool_registry.register_tool(tool_2)
+
+    # Verify the second tool replaced the first
+    assert tool_registry.registry["hot_reload_tool"] == tool_2

From 237e1881323dbfa909688a149d65dccc6ee8bd40 Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Wed, 3 Sep 2025 09:56:49 -0400
Subject: [PATCH 067/104] fix: don't emit ToolStream events for non generator
 functions (#773)

Our current implementation of AgentTool.stream() has a problem: we
don't differentiate between intermediate streaming events and the final
ToolResult events. Our only contract is that the last event *must be*
the tool result that is passed to the LLM.

Our switch to Typed Events (#755) pushes us in the right direction, but
for backwards compatibility we can't update the signature of
`AgentTool.stream()` (nor have we exposed TypedEvents externally yet).
That means that if we implemented tool-streaming today, then callers
would see non-generator functions yielding both a `ToolStreamEvent` and
a `ToolResultEvent` even though they're not actually streaming responses.

To avoid the odd behavior noted above, we'll special-case SDK-defined
functions by allowing them to emit `ToolStreamEvent` and
`ToolResultEvent` types directly (bypassing our normal wrapping), since
they have the knowledge of whether tools are actually generators or not.

There's no observable difference in behavior to callers (this is all
internal behavior), but this means that when we flip the switch for Tool
Streaming, non-generator tools will **not** emit ToolStreamEvents - at
least for AgentTool implementations that are in the SDK.
Co-authored-by: Mackenzie Zastrow --- src/strands/tools/decorator.py | 80 ++++--- src/strands/tools/executors/_executor.py | 15 +- src/strands/tools/mcp/mcp_agent_tool.py | 3 +- src/strands/tools/mcp/mcp_types.py | 2 +- src/strands/tools/tools.py | 5 +- .../tools/executors/test_concurrent.py | 6 +- .../strands/tools/executors/test_executor.py | 73 +++++- .../tools/executors/test_sequential.py | 6 +- .../strands/tools/mcp/test_mcp_agent_tool.py | 3 +- tests/strands/tools/test_decorator.py | 217 ++++++++++-------- tests/strands/tools/test_tools.py | 3 +- 11 files changed, 271 insertions(+), 142 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 2ce6d946f..8b218dfa1 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -53,6 +53,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: Type, TypeVar, Union, + cast, get_type_hints, overload, ) @@ -61,7 +62,8 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic import BaseModel, Field, create_model from typing_extensions import override -from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse +from ..types._events import ToolResultEvent, ToolStreamEvent +from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -454,43 +456,67 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw # Inject special framework-provided parameters self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state) - # "Too few arguments" expected, hence the type ignore - if inspect.iscoroutinefunction(self._tool_func): + # Note: "Too few arguments" expected for the _tool_func calls, hence the type ignore + + # Async-generators, yield streaming events and final tool result + if inspect.isasyncgenfunction(self._tool_func): + sub_events = self._tool_func(**validated_input) # type: ignore + async for sub_event in sub_events: + yield ToolStreamEvent(tool_use, sub_event) + + # The last event is the result + yield self._wrap_tool_result(tool_use_id, sub_event) + + # Async functions, yield only the result + elif inspect.iscoroutinefunction(self._tool_func): result = await self._tool_func(**validated_input) # type: ignore - else: - result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore + yield self._wrap_tool_result(tool_use_id, result) - # FORMAT THE RESULT for Strands Agent - if isinstance(result, dict) and "status" in result and "content" in result: - # Result is already in the expected format, just add toolUseId - result["toolUseId"] = tool_use_id - yield result + # Other functions, yield only the result else: - # Wrap any other return value in the standard format - # Always include at least one content item for consistency - yield { - "toolUseId": tool_use_id, - "status": "success", - "content": [{"text": str(result)}], - } + result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore + yield self._wrap_tool_result(tool_use_id, result) except ValueError as e: # Special handling for validation errors error_msg = str(e) - yield { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error_msg}"}], - } + yield self._wrap_tool_result( + tool_use_id, + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error_msg}"}], + }, + ) except Exception as e: # Return error result with exception details for any 
other error error_type = type(e).__name__ error_msg = str(e) - yield { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error_type} - {error_msg}"}], - } + yield self._wrap_tool_result( + tool_use_id, + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error_type} - {error_msg}"}], + }, + ) + + def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent: + # FORMAT THE RESULT for Strands Agent + if isinstance(result, dict) and "status" in result and "content" in result: + # Result is already in the expected format, just add toolUseId + result["toolUseId"] = tool_use_d + return ToolResultEvent(cast(ToolResult, result)) + else: + # Wrap any other return value in the standard format + # Always include at least one content item for consistency + return ToolResultEvent( + { + "toolUseId": tool_use_d, + "status": "success", + "content": [{"text": str(result)}], + } + ) @property def supports_hot_reload(self) -> bool: diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 701a3bac0..5354991c3 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -119,7 +119,20 @@ async def _stream( return async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): - yield ToolStreamEvent(tool_use, event) + # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() + # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. + # In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent + # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in + # ToolStreamEvent and the last even is just the result + + if isinstance(event, ToolResultEvent): + # below the last "event" must point to the tool_result + event = event.tool_result + break + elif isinstance(event, ToolStreamEvent): + yield event + else: + yield ToolStreamEvent(tool_use, event) result = cast(ToolResult, event) diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index f9c8d6061..f15bb1718 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -11,6 +11,7 @@ from mcp.types import Tool as MCPTool from typing_extensions import override +from ...types._events import ToolResultEvent from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse if TYPE_CHECKING: @@ -96,4 +97,4 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw name=self.tool_name, arguments=tool_use["input"], ) - yield result + yield ToolResultEvent(result) diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 5fafed5dc..66eda08ae 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -9,7 +9,7 @@ from mcp.shared.message import SessionMessage from typing_extensions import NotRequired -from strands.types.tools import ToolResult +from ...types.tools import ToolResult """ MCPTransport defines the interface for MCP transport implementations. 
This abstracts diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 465063095..9e1c0e608 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -12,6 +12,7 @@ from typing_extensions import override +from ..types._events import ToolResultEvent from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -211,7 +212,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw """ if inspect.iscoroutinefunction(self._tool_func): result = await self._tool_func(tool_use, **invocation_state) + yield ToolResultEvent(result) else: result = await asyncio.to_thread(self._tool_func, tool_use, **invocation_state) - - yield result + yield ToolResultEvent(result) diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index 140537add..f7fc64b25 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,7 +1,7 @@ import pytest from strands.tools.executors import ConcurrentToolExecutor -from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolResultEvent from strands.types.tools import ToolUse @@ -22,13 +22,11 @@ async def test_concurrent_executor_execute( tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id) exp_events = [ - ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), - ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) - exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] + exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 56caa950a..903a11e5a 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -1,4 +1,5 @@ import unittest.mock +from unittest.mock import MagicMock import pytest @@ -39,7 +40,6 @@ async def test_executor_stream_yields_result( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events @@ -67,6 +67,76 @@ async def test_executor_stream_yields_result( assert tru_hook_events == exp_hook_events +@pytest.mark.asyncio +async def test_executor_stream_wraps_results( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + weather_tool.stream.return_value = agenerator( + ["value 1", {"nested": True}, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}] + ) + + tru_events = await alist(stream) + exp_events = [ + ToolStreamEvent(tool_use, "value 1"), + ToolStreamEvent(tool_use, 
{"nested": True}), + ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ] + assert tru_events == exp_events + + +@pytest.mark.asyncio +async def test_executor_stream_passes_through_typed_events( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + event_1 = ToolStreamEvent(tool_use, "value 1") + event_2 = ToolStreamEvent(tool_use, {"nested": True}) + event_3 = ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}) + weather_tool.stream.return_value = agenerator( + [ + event_1, + event_2, + event_3, + ] + ) + + tru_events = await alist(stream) + assert tru_events[0] is event_1 + assert tru_events[1] is event_2 + + # ToolResults are not passed through directly, they're unwrapped then wraped again + assert tru_events[2] == event_3 + + +@pytest.mark.asyncio +async def test_executor_stream_wraps_stream_events_if_no_result( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + last_event = ToolStreamEvent(tool_use, "value 1") + # Only ToolResultEvent can be the last value; all others are wrapped in ToolResultEvent + weather_tool.stream.return_value = agenerator( + [ + last_event, + ] + ) + + tru_events = await alist(stream) + exp_events = [last_event, ToolResultEvent(last_event)] + assert tru_events == exp_events + + @pytest.mark.asyncio async def test_executor_stream_yields_tool_error( executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist @@ -129,7 +199,6 @@ async def test_executor_stream_with_trace( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index d4e98223e..37e098142 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,7 +1,7 @@ import pytest from strands.tools.executors import SequentialToolExecutor -from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolResultEvent @pytest.fixture @@ -21,13 +21,11 @@ async def test_sequential_executor_execute( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), - ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] + exp_results = [exp_events[0].tool_result, 
exp_events[1].tool_result] assert tru_results == exp_results diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/tools/mcp/test_mcp_agent_tool.py index 874006683..1c025f5f2 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/tools/mcp/test_mcp_agent_tool.py @@ -4,6 +4,7 @@ from mcp.types import Tool as MCPTool from strands.tools.mcp import MCPAgentTool, MCPClient +from strands.types._events import ToolResultEvent @pytest.fixture @@ -62,7 +63,7 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist): tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}} tru_events = await alist(mcp_agent_tool.stream(tool_use, {})) - exp_events = [mock_mcp_client.call_tool_async.return_value] + exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)] assert tru_events == exp_events mock_mcp_client.call_tool_async.assert_called_once_with( diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 02e7eb445..a13c2833e 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -10,6 +10,7 @@ import strands from strands import Agent +from strands.types._events import ToolResultEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -117,7 +118,7 @@ async def test_stream(identity_tool, alist): stream = identity_tool.stream({"toolUseId": "t1", "input": {"a": 2}}, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "t1", "status": "success", "content": [{"text": "2"}]}] + exp_events = [ToolResultEvent({"toolUseId": "t1", "status": "success", "content": [{"text": "2"}]})] assert tru_events == exp_events @@ -131,7 +132,9 @@ def identity(a: int, agent: dict = None): stream = identity.stream({"input": {"a": 2}}, {"agent": {"state": 1}}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}) + ] assert tru_events == exp_events @@ -180,7 +183,9 @@ def test_tool(param1: str, param2: int) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) + ] assert tru_events == exp_events # Make sure these are set properly @@ -229,7 +234,9 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello"}]}) + ] assert tru_events == exp_events # Test with both params @@ -237,7 +244,9 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) + ] @pytest.mark.asyncio @@ -256,8 +265,8 @@ def test_tool(required: str) -> str: stream = 
test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "validation error for test_tooltool\nrequired\n" in result["content"][0]["text"].lower(), ( + assert result["tool_result"]["status"] == "error" + assert "validation error for test_tooltool\nrequired\n" in result["tool_result"]["content"][0]["text"].lower(), ( "Validation error should indicate which argument is missing" ) @@ -266,8 +275,8 @@ def test_tool(required: str) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "test error" in result["content"][0]["text"].lower(), ( + assert result["tool_result"]["status"] == "error" + assert "test error" in result["tool_result"]["content"][0]["text"].lower(), ( "Runtime error should contain the original error message" ) @@ -313,14 +322,14 @@ def test_tool(param: str, agent=None) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "Param: test" + assert result["tool_result"]["content"][0]["text"] == "Param: test" # Test with agent stream = test_tool.stream(tool_use, {"agent": mock_agent}) result = (await alist(stream))[-1] - assert "Agent:" in result["content"][0]["text"] - assert "test" in result["content"][0]["text"] + assert "Agent:" in result["tool_result"]["content"][0]["text"] + assert "test" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -350,23 +359,23 @@ def none_return_tool(param: str) -> None: stream = dict_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Result: test" - assert result["toolUseId"] == "test-id" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Result: test" + assert result["tool_result"]["toolUseId"] == "test-id" # Test the string return - should wrap in standard format stream = string_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Result: test" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Result: test" # Test None return - should still create valid ToolResult with "None" text stream = none_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" @pytest.mark.asyncio @@ -403,7 +412,7 @@ def test_method(self, param: str) -> str: stream = instance.test_method.stream(tool_use, {}) result = (await alist(stream))[-1] - assert "Test: tool-value" in result["content"][0]["text"] + assert "Test: tool-value" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -422,7 +431,9 @@ class MyThing: ... 
stream = instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}, {}) result2 = (await alist(stream))[-1] - assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + assert result2 == ToolResultEvent( + {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + ) @pytest.mark.asyncio @@ -444,7 +455,9 @@ def test_method(param: str) -> str: stream = instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}, {}) result2 = (await alist(stream))[-1] - assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + assert result2 == ToolResultEvent( + {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + ) @pytest.mark.asyncio @@ -474,14 +487,14 @@ def tool_with_defaults(required: str, optional: str = "default", number: int = 4 stream = tool_with_defaults.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "hello default 42" + assert result["tool_result"]["content"][0]["text"] == "hello default 42" # Call with some but not all optional parameters tool_use = {"toolUseId": "test-id", "input": {"required": "hello", "number": 100}} stream = tool_with_defaults.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "hello default 100" + assert result["tool_result"]["content"][0]["text"] == "hello default 100" @pytest.mark.asyncio @@ -496,14 +509,15 @@ def test_tool(required: str) -> str: # Test with completely empty tool use stream = test_tool.stream({}, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "unknown" in result["toolUseId"] + print(result) + assert result["tool_result"]["status"] == "error" + assert "unknown" in result["tool_result"]["toolUseId"] # Test with missing input stream = test_tool.stream({"toolUseId": "test-id"}, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "test-id" in result["toolUseId"] + assert result["tool_result"]["status"] == "error" + assert "test-id" in result["tool_result"]["toolUseId"] @pytest.mark.asyncio @@ -529,8 +543,8 @@ def add_numbers(a: int, b: int) -> int: stream = add_numbers.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "5" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "5" @pytest.mark.asyncio @@ -565,8 +579,8 @@ def multi_default_tool( stream = multi_default_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "hello, default_str, 42, True, 3.14" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "hello, default_str, 42, True, 3.14" in result["tool_result"]["content"][0]["text"] # Test calling with some optional parameters tool_use = { @@ -576,7 +590,7 @@ def multi_default_tool( stream = multi_default_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert "hello, default_str, 100, True, 2.718" in result["content"][0]["text"] + assert "hello, default_str, 100, True, 2.718" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -603,8 +617,8 @@ def int_return_tool(param: str) -> int: stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert 
result["content"][0]["text"] == "42" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "42" # Test with return that doesn't match declared type # Note: This should still work because Python doesn't enforce return types at runtime @@ -613,16 +627,16 @@ def int_return_tool(param: str) -> int: stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "not an int" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "not an int" # Test with None return from a non-None return type tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" # Define tool with Union return type @strands.tool @@ -644,22 +658,25 @@ def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "{'key': 'value'}" in result["content"][0]["text"] or '{"key": "value"}' in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert ( + "{'key': 'value'}" in result["tool_result"]["content"][0]["text"] + or '{"key": "value"}' in result["tool_result"]["content"][0]["text"] + ) tool_use = {"toolUseId": "test-id", "input": {"param": "str"}} stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "string result" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "string result" tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" @pytest.mark.asyncio @@ -682,8 +699,8 @@ def no_params_tool() -> str: stream = no_params_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Success - no parameters needed" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Success - no parameters needed" # Test direct call direct_result = no_params_tool() @@ -711,8 +728,8 @@ def complex_type_tool(config: Dict[str, Any]) -> str: stream = complex_type_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "Got config with 3 keys" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "Got config with 3 keys" in result["tool_result"]["content"][0]["text"] # Direct call direct_result = complex_type_tool(nested_dict) @@ -742,12 +759,12 @@ def custom_result_tool(param: str) -> Dict[str, Any]: # The wrapper should preserve our format and just add the toolUseId result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["toolUseId"] == "custom-id" - assert 
len(result["content"]) == 2 - assert result["content"][0]["text"] == "First line: test" - assert result["content"][1]["text"] == "Second line" - assert result["content"][1]["type"] == "markdown" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["toolUseId"] == "custom-id" + assert len(result["tool_result"]["content"]) == 2 + assert result["tool_result"]["content"][0]["text"] == "First line: test" + assert result["tool_result"]["content"][1]["text"] == "Second line" + assert result["tool_result"]["content"][1]["type"] == "markdown" def test_docstring_parsing(): @@ -816,8 +833,8 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: stream = validation_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "int_param" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "error" + assert "int_param" in result["tool_result"]["content"][0]["text"] # Test missing required parameter tool_use = { @@ -831,8 +848,8 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: stream = validation_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "int_param" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "error" + assert "int_param" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -855,16 +872,16 @@ def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" # Test with empty dict tool_use = {"toolUseId": "test-id", "input": {"param": {}}} stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "{}" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "{}" # Test with a complex nested dictionary nested_dict = {"key1": {"nested": [1, 2, 3]}, "key2": None} @@ -872,9 +889,9 @@ def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "key1" in result["content"][0]["text"] - assert "nested" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "key1" in result["tool_result"]["content"][0]["text"] + assert "nested" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -922,8 +939,8 @@ def test_method(self): stream = instance.test_method.stream({"toolUseId": "test-id", "input": {"param": "direct"}}, {}) direct_result = (await alist(stream))[-1] - assert direct_result["status"] == "success" - assert direct_result["content"][0]["text"] == "Method Got: direct" + assert direct_result["tool_result"]["status"] == "success" + assert direct_result["tool_result"]["content"][0]["text"] == "Method Got: direct" # Create a standalone function to test regular function calls @strands.tool @@ -944,8 +961,8 @@ def standalone_tool(p1: str, p2: str = "default") -> str: stream = standalone_tool.stream({"toolUseId": "test-id", "input": {"p1": "value1"}}, {}) tool_use_result = (await alist(stream))[-1] - assert tool_use_result["status"] == 
"success" - assert tool_use_result["content"][0]["text"] == "Standalone: value1, default" + assert tool_use_result["tool_result"]["status"] == "success" + assert tool_use_result["tool_result"]["content"][0]["text"] == "Standalone: value1, default" @pytest.mark.asyncio @@ -976,9 +993,9 @@ def failing_tool(param: str) -> str: stream = failing_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" + assert result["tool_result"]["status"] == "error" - error_message = result["content"][0]["text"] + error_message = result["tool_result"]["content"][0]["text"] # Check that error type is included if error_type == "value_error": @@ -1011,33 +1028,33 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None] stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "list: [1, 2, 3]" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "list: [1, 2, 3]" in result["tool_result"]["content"][0]["text"] # Test with a dict tool_use = {"toolUseId": "test-id", "input": {"union_param": {"key": "value"}}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "dict:" in result["content"][0]["text"] - assert "key" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "dict:" in result["tool_result"]["content"][0]["text"] + assert "key" in result["tool_result"]["content"][0]["text"] # Test with a string tool_use = {"toolUseId": "test-id", "input": {"union_param": "test_string"}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "str: test_string" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "str: test_string" in result["tool_result"]["content"][0]["text"] # Test with None tool_use = {"toolUseId": "test-id", "input": {"union_param": None}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "NoneType: None" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "NoneType: None" in result["tool_result"]["content"][0]["text"] async def _run_context_injection_test(context_tool: AgentTool, additional_context=None): @@ -1061,15 +1078,17 @@ async def _run_context_injection_test(context_tool: AgentTool, additional_contex assert len(tool_results) == 1 tool_result = tool_results[0] - assert tool_result == { - "status": "success", - "content": [ - {"text": "Tool 'context_tool' (ID: test-id)"}, - {"text": "injected agent 'test_agent' processed: some_message"}, - {"text": "context agent 'test_agent'"}, - ], - "toolUseId": "test-id", - } + assert tool_result == ToolResultEvent( + { + "status": "success", + "content": [ + {"text": "Tool 'context_tool' (ID: test-id)"}, + {"text": "injected agent 'test_agent' processed: some_message"}, + {"text": "context agent 'test_agent'"}, + ], + "toolUseId": "test-id", + } + ) @pytest.mark.asyncio @@ -1164,9 +1183,9 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> dict: tool_result = tool_results[0] # Should get a validation error because tool_context is required but not provided - assert tool_result["status"] == "error" - assert "tool_context" in tool_result["content"][0]["text"].lower() - assert "validation" in 
tool_result["content"][0]["text"].lower() + assert tool_result["tool_result"]["status"] == "error" + assert "tool_context" in tool_result["tool_result"]["content"][0]["text"].lower() + assert "validation" in tool_result["tool_result"]["content"][0]["text"].lower() @pytest.mark.asyncio @@ -1196,8 +1215,10 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str: tool_result = tool_results[0] # Should succeed with the string parameter - assert tool_result == { - "status": "success", - "content": [{"text": "success"}], - "toolUseId": "test-id-2", - } + assert tool_result == ToolResultEvent( + { + "status": "success", + "content": [{"text": "success"}], + "toolUseId": "test-id-2", + } + ) diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index 240c24717..b305a1a90 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -9,6 +9,7 @@ validate_tool_use, validate_tool_use_name, ) +from strands.types._events import ToolResultEvent from strands.types.tools import ToolUse @@ -506,5 +507,5 @@ async def test_stream(identity_tool, alist): stream = identity_tool.stream({"tool_use": 1}, {"a": 2}) tru_events = await alist(stream) - exp_events = [({"tool_use": 1}, 2)] + exp_events = [ToolResultEvent(({"tool_use": 1}, 2))] assert tru_events == exp_events From 4dee33b32cc10be9ab6d75b80198119ee3009417 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 3 Sep 2025 10:28:26 -0400 Subject: [PATCH 068/104] fix(tests): adjust test_bedrock_guardrails to account for async behavior (#785) --- src/strands/tools/registry.py | 6 ++-- tests_integ/test_bedrock_guardrails.py | 45 ++++++++++++++++++++++---- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 6bb76f560..471472a64 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -192,9 +192,9 @@ def register_tool(self, tool: AgentTool) -> None: # Check duplicate tool name, throw on duplicate tool names except if hot_reloading is enabled if tool.tool_name in self.registry and not tool.supports_hot_reload: - raise ValueError( - f"Tool name '{tool.tool_name}' already exists. Cannot register tools with exact same name." - ) + raise ValueError( + f"Tool name '{tool.tool_name}' already exists. Cannot register tools with exact same name." + ) # Check for normalized name conflicts (- vs _) if self.registry.get(tool.tool_name) is None: diff --git a/tests_integ/test_bedrock_guardrails.py b/tests_integ/test_bedrock_guardrails.py index 4683918cb..e25bf3cca 100644 --- a/tests_integ/test_bedrock_guardrails.py +++ b/tests_integ/test_bedrock_guardrails.py @@ -138,9 +138,25 @@ def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processi response1 = agent("Say the word.") response2 = agent("Hello!") assert response1.stop_reason == "guardrail_intervened" - assert BLOCKED_OUTPUT in str(response1) - assert response2.stop_reason != "guardrail_intervened" - assert BLOCKED_OUTPUT not in str(response2) + + """ + In async streaming: The buffering is non-blocking. + Tokens are streamed while Guardrails processes the buffered content in the background. + This means the response may be returned before Guardrails has finished processing. 
+ As a result, we cannot guarantee that the REDACT_MESSAGE is in the response + """ + if processing_mode == "sync": + assert BLOCKED_OUTPUT in str(response1) + assert response2.stop_reason != "guardrail_intervened" + assert BLOCKED_OUTPUT not in str(response2) + else: + cactus_returned_in_response1_blocked_by_input_guardrail = BLOCKED_INPUT in str(response2) + cactus_blocked_in_response1_allows_next_response = ( + BLOCKED_OUTPUT not in str(response2) and response2.stop_reason != "guardrail_intervened" + ) + assert ( + cactus_returned_in_response1_blocked_by_input_guardrail or cactus_blocked_in_response1_allows_next_response + ) @pytest.mark.parametrize("processing_mode", ["sync", "async"]) @@ -164,10 +180,27 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi response1 = agent("Say the word.") response2 = agent("Hello!") + assert response1.stop_reason == "guardrail_intervened" - assert REDACT_MESSAGE in str(response1) - assert response2.stop_reason != "guardrail_intervened" - assert REDACT_MESSAGE not in str(response2) + + """ + In async streaming: The buffering is non-blocking. + Tokens are streamed while Guardrails processes the buffered content in the background. + This means the response may be returned before Guardrails has finished processing. + As a result, we cannot guarantee that the REDACT_MESSAGE is in the response + """ + if processing_mode == "sync": + assert REDACT_MESSAGE in str(response1) + assert response2.stop_reason != "guardrail_intervened" + assert REDACT_MESSAGE not in str(response2) + else: + cactus_returned_in_response1_blocked_by_input_guardrail = BLOCKED_INPUT in str(response2) + cactus_blocked_in_response1_allows_next_response = ( + REDACT_MESSAGE not in str(response2) and response2.stop_reason != "guardrail_intervened" + ) + assert ( + cactus_returned_in_response1_blocked_by_input_guardrail or cactus_blocked_in_response1_allows_next_response + ) def test_guardrail_input_intervention_properly_redacts_in_session(boto_session, bedrock_guardrail, temp_dir): From 2db52266a5b66ea08f692288d64c5871f57fd968 Mon Sep 17 00:00:00 2001 From: Deepesh Dhakal Date: Thu, 4 Sep 2025 00:37:08 +0900 Subject: [PATCH 069/104] fix(doc): replace invalid Hook names in doc comment with BeforeInvocationEvent & AfterInvocationEvent (#782) Co-authored-by: deepyes02 --- src/strands/hooks/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 77be9d64e..b98e95a6e 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -8,17 +8,17 @@ Example Usage: ```python from strands.hooks import HookProvider, HookRegistry - from strands.hooks.events import StartRequestEvent, EndRequestEvent + from strands.hooks.events import BeforeInvocationEvent, AfterInvocationEvent class LoggingHooks(HookProvider): def register_hooks(self, registry: HookRegistry) -> None: - registry.add_callback(StartRequestEvent, self.log_start) - registry.add_callback(EndRequestEvent, self.log_end) + registry.add_callback(BeforeInvocationEvent, self.log_start) + registry.add_callback(AfterInvocationEvent, self.log_end) - def log_start(self, event: StartRequestEvent) -> None: + def log_start(self, event: BeforeInvocationEvent) -> None: print(f"Request started for {event.agent.name}") - def log_end(self, event: EndRequestEvent) -> None: + def log_end(self, event: AfterInvocationEvent) -> None: print(f"Request completed for {event.agent.name}") # Use with agent From 
1e6d12d755066d21ce27e693f67f7dcc2577aa33 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Thu, 4 Sep 2025 09:13:14 -0700 Subject: [PATCH 070/104] fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider (#686) * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider * fix: Remove status field from toolResult for non-claude 3 models in Bedrock model provider --- src/strands/models/bedrock.py | 38 ++++++++++++--- tests/strands/models/test_bedrock.py | 73 ++++++++++++++++++++++++++-- 2 files changed, 102 insertions(+), 9 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index ba4828c1a..b1628d817 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -37,6 +37,11 @@ "too many total text bytes", ] +# Models that should include tool result status (include_tool_result_status = True) +_MODELS_INCLUDE_STATUS = [ + "anthropic.claude", +] + T = TypeVar("T", bound=BaseModel) @@ -71,6 +76,8 @@ class BedrockConfig(TypedDict, total=False): guardrail_redact_output_message: If a Bedrock Output guardrail triggers, replace output with this message. max_tokens: Maximum number of tokens to generate in the response model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0") + include_tool_result_status: Flag to include status field in tool results. + True includes status, False removes status, "auto" determines based on model_id. Defaults to "auto". stop_sequences: List of sequences that will stop generation when encountered streaming: Flag to enable/disable streaming. Defaults to True. 
temperature: Controls randomness in generation (higher = more random) @@ -92,6 +99,7 @@ class BedrockConfig(TypedDict, total=False): guardrail_redact_output_message: Optional[str] max_tokens: Optional[int] model_id: str + include_tool_result_status: Optional[Literal["auto"] | bool] stop_sequences: Optional[list[str]] streaming: Optional[bool] temperature: Optional[float] @@ -119,7 +127,7 @@ def __init__( if region_name and boto_session: raise ValueError("Cannot specify both `region_name` and `boto_session`.") - self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID) + self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID, include_tool_result_status="auto") self.update_config(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -169,6 +177,17 @@ def get_config(self) -> BedrockConfig: """ return self.config + def _should_include_tool_result_status(self) -> bool: + """Determine whether to include tool result status based on current config.""" + include_status = self.config.get("include_tool_result_status", "auto") + + if include_status is True: + return True + elif include_status is False: + return False + else: # "auto" + return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS) + def format_request( self, messages: Messages, @@ -282,10 +301,18 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: # Create a new content block with only the cleaned toolResult tool_result: ToolResult = content_block["toolResult"] - # Keep only the required fields for Bedrock - cleaned_tool_result = ToolResult( - content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"] - ) + if self._should_include_tool_result_status(): + # Include status field + cleaned_tool_result = ToolResult( + content=tool_result["content"], + toolUseId=tool_result["toolUseId"], + status=tool_result["status"], + ) + else: + # Remove status field + cleaned_tool_result = ToolResult( # type: ignore[typeddict-item] + toolUseId=tool_result["toolUseId"], content=tool_result["content"] + ) cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result} cleaned_content.append(cleaned_block) @@ -296,7 +323,6 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: # Create new message with cleaned content cleaned_message: Message = Message(content=cleaned_content, role=message["role"]) cleaned_messages.append(cleaned_message) - return cleaned_messages def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool: diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2f44c2e65..e0f7879c0 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1275,7 +1275,6 @@ async def test_stream_stop_reason_override_non_streaming(bedrock_client, alist, def test_format_request_cleans_tool_result_content_blocks(model, model_id): - """Test that format_request cleans toolResult blocks by removing extra fields.""" messages = [ { "role": "user", @@ -1295,9 +1294,77 @@ def test_format_request_cleans_tool_result_content_blocks(model, model_id): formatted_request = model.format_request(messages) - # Verify toolResult only contains allowed fields in the formatted request tool_result = formatted_request["messages"][0]["content"][0]["toolResult"] - expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"} + expected = {"toolUseId": "tool123", "content": [{"text": "Tool output"}]} assert 
tool_result == expected assert "extraField" not in tool_result assert "mcpMetadata" not in tool_result + assert "status" not in tool_result + + +def test_format_request_removes_status_field_when_configured(model, model_id): + model.update_config(include_tool_result_status=False) + + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "content": [{"text": "Tool output"}], + "toolUseId": "tool123", + "status": "success", + } + }, + ], + } + ] + + formatted_request = model.format_request(messages) + + tool_result = formatted_request["messages"][0]["content"][0]["toolResult"] + expected = {"toolUseId": "tool123", "content": [{"text": "Tool output"}]} + assert tool_result == expected + assert "status" not in tool_result + + +def test_auto_behavior_anthropic_vs_non_anthropic(bedrock_client): + model_anthropic = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") + assert model_anthropic.get_config()["include_tool_result_status"] == "auto" + + model_non_anthropic = BedrockModel(model_id="amazon.titan-text-v1") + assert model_non_anthropic.get_config()["include_tool_result_status"] == "auto" + + +def test_explicit_boolean_values_preserved(bedrock_client): + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", include_tool_result_status=True) + assert model.get_config()["include_tool_result_status"] is True + + model2 = BedrockModel(model_id="amazon.titan-text-v1", include_tool_result_status=False) + assert model2.get_config()["include_tool_result_status"] is False + """Test that format_request keeps status field by default for anthropic.claude models.""" + # Default model is anthropic.claude, so should keep status + model = BedrockModel() + + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "content": [{"text": "Tool output"}], + "toolUseId": "tool123", + "status": "success", + } + }, + ], + } + ] + + formatted_request = model.format_request(messages) + + # Verify toolResult contains status field by default + tool_result = formatted_request["messages"][0]["content"][0]["toolResult"] + expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"} + assert tool_result == expected + assert "status" in tool_result From ed3386823a58b15d0faa407ebfe5c1a36ff76d75 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Fri, 5 Sep 2025 01:58:47 +0800 Subject: [PATCH 071/104] fix: filter 'SDK_UNKNOWN_MEMBER' from response content (#798) Co-authored-by: Jack Yuan --- src/strands/models/bedrock.py | 15 ++++++++++++++- tests/strands/models/test_bedrock.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index b1628d817..8a6d5116f 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -180,7 +180,7 @@ def get_config(self) -> BedrockConfig: def _should_include_tool_result_status(self) -> bool: """Determine whether to include tool result status based on current config.""" include_status = self.config.get("include_tool_result_status", "auto") - + if include_status is True: return True elif include_status is False: @@ -275,6 +275,7 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: """Format messages for Bedrock API compatibility. 
This function ensures messages conform to Bedrock's expected format by: + - Filtering out SDK_UNKNOWN_MEMBER content blocks - Cleaning tool result content blocks by removing additional fields that may be useful for retaining information in hooks but would cause Bedrock validation exceptions when presented with unexpected fields @@ -292,11 +293,17 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html """ cleaned_messages = [] + filtered_unknown_members = False for message in messages: cleaned_content: list[ContentBlock] = [] for content_block in message["content"]: + # Filter out SDK_UNKNOWN_MEMBER content blocks + if "SDK_UNKNOWN_MEMBER" in content_block: + filtered_unknown_members = True + continue + if "toolResult" in content_block: # Create a new content block with only the cleaned toolResult tool_result: ToolResult = content_block["toolResult"] @@ -323,6 +330,12 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: # Create new message with cleaned content cleaned_message: Message = Message(content=cleaned_content, role=message["role"]) cleaned_messages.append(cleaned_message) + + if filtered_unknown_members: + logger.warning( + "Filtered out SDK_UNKNOWN_MEMBER content blocks from messages, consider upgrading boto3 version" + ) + return cleaned_messages def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool: diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index e0f7879c0..13918b6ea 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1331,7 +1331,7 @@ def test_format_request_removes_status_field_when_configured(model, model_id): def test_auto_behavior_anthropic_vs_non_anthropic(bedrock_client): model_anthropic = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") assert model_anthropic.get_config()["include_tool_result_status"] == "auto" - + model_non_anthropic = BedrockModel(model_id="amazon.titan-text-v1") assert model_non_anthropic.get_config()["include_tool_result_status"] == "auto" @@ -1339,7 +1339,7 @@ def test_auto_behavior_anthropic_vs_non_anthropic(bedrock_client): def test_explicit_boolean_values_preserved(bedrock_client): model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", include_tool_result_status=True) assert model.get_config()["include_tool_result_status"] is True - + model2 = BedrockModel(model_id="amazon.titan-text-v1", include_tool_result_status=False) assert model2.get_config()["include_tool_result_status"] is False """Test that format_request keeps status field by default for anthropic.claude models.""" @@ -1368,3 +1368,27 @@ def test_explicit_boolean_values_preserved(bedrock_client): expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"} assert tool_result == expected assert "status" in tool_result + + +def test_format_request_filters_sdk_unknown_member_content_blocks(model, model_id, caplog): + """Test that format_request filters out SDK_UNKNOWN_MEMBER content blocks.""" + messages = [ + { + "role": "assistant", + "content": [ + {"text": "Hello"}, + {"SDK_UNKNOWN_MEMBER": {"name": "reasoningContent"}}, + {"text": "World"}, + ], + } + ] + + formatted_request = model.format_request(messages) + + content = formatted_request["messages"][0]["content"] + assert len(content) == 2 + assert content[0] == {"text": "Hello"} + assert content[1] == {"text": "World"} + + for 
block in content: + assert "SDK_UNKNOWN_MEMBER" not in block From d07629f28645250d2a8a2e06a367751223612543 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 4 Sep 2025 14:39:30 -0400 Subject: [PATCH 072/104] feat: Implement async generator tools (#788) Enable decorated tools to be an async generator, enabling streaming of tool events back to to the caller. --------- Co-authored-by: Mackenzie Zastrow --- src/strands/types/_events.py | 13 +- .../strands/agent/hooks/test_agent_events.py | 22 +-- tests/strands/tools/test_decorator.py | 145 +++++++++++++++++- 3 files changed, 160 insertions(+), 20 deletions(-) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 1a7f48d4b..ccdab1846 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -275,24 +275,19 @@ def is_callback_event(self) -> bool: class ToolStreamEvent(TypedEvent): """Event emitted when a tool yields sub-events as part of tool execution.""" - def __init__(self, tool_use: ToolUse, tool_sub_event: Any) -> None: + def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None: """Initialize with tool streaming data. Args: tool_use: The tool invocation producing the stream - tool_sub_event: The yielded event from the tool execution + tool_stream_data: The yielded event from the tool execution """ - super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_event": tool_sub_event}) + super().__init__({"tool_stream_event": {"tool_use": tool_use, "data": tool_stream_data}}) @property def tool_use_id(self) -> str: """The toolUseId associated with this stream.""" - return cast(str, cast(ToolUse, self.get("tool_stream_tool_use")).get("toolUseId")) - - @property - @override - def is_callback_event(self) -> bool: - return False + return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId")) class ModelMessageEvent(TypedEvent): diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 04b832259..07f55b724 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -260,18 +260,22 @@ async def test_stream_e2e_success(alist): "role": "assistant", } }, + { + "tool_stream_event": { + "data": {"tool_streaming": True}, + "tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, + } + }, + { + "tool_stream_event": { + "data": "Final result", + "tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, + } + }, { "message": { "content": [ - { - "toolResult": { - # TODO update this text when we get tool streaming implemented; right now this - # TODO is of the form '' - "content": [{"text": ANY}], - "status": "success", - "toolUseId": "12345", - } - }, + {"toolResult": {"content": [{"text": "Final result"}], "status": "success", "toolUseId": "12345"}} ], "role": "user", } diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index a13c2833e..5b4b5cdda 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,14 +3,14 @@ """ from asyncio import Queue -from typing import Any, Dict, Optional, Union +from typing import Any, AsyncGenerator, Dict, Optional, Union from unittest.mock import MagicMock import pytest import strands from strands import Agent -from strands.types._events import ToolResultEvent +from strands.types._events import ToolResultEvent, ToolStreamEvent from 
strands.types.tools import AgentTool, ToolContext, ToolUse @@ -1222,3 +1222,144 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str: "toolUseId": "test-id-2", } ) + + +@pytest.mark.asyncio +async def test_tool_async_generator(): + """Test that async generators yield results appropriately.""" + + @strands.tool(context=False) + async def async_generator() -> AsyncGenerator: + """Tool that expects tool_context as a regular string parameter.""" + yield 0 + yield "Value 1" + yield {"nested": "value"} + yield { + "status": "success", + "content": [{"text": "Looks like tool result"}], + "toolUseId": "test-id-2", + } + yield "final result" + + tool: AgentTool = async_generator + tool_use: ToolUse = { + "toolUseId": "test-id-2", + "name": "context_tool", + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, + } + generator = tool.stream( + tool_use=tool_use, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + act_results = [value async for value in generator] + exp_results = [ + ToolStreamEvent(tool_use, 0), + ToolStreamEvent(tool_use, "Value 1"), + ToolStreamEvent(tool_use, {"nested": "value"}), + ToolStreamEvent( + tool_use, + { + "status": "success", + "content": [{"text": "Looks like tool result"}], + "toolUseId": "test-id-2", + }, + ), + ToolStreamEvent(tool_use, "final result"), + ToolResultEvent( + { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + } + ), + ] + + assert act_results == exp_results + + +@pytest.mark.asyncio +async def test_tool_async_generator_exceptions_result_in_error(): + """Test that async generators handle exceptions.""" + + @strands.tool(context=False) + async def async_generator() -> AsyncGenerator: + """Tool that expects tool_context as a regular string parameter.""" + yield 13 + raise ValueError("It's an error!") + + tool: AgentTool = async_generator + tool_use: ToolUse = { + "toolUseId": "test-id-2", + "name": "context_tool", + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, + } + generator = tool.stream( + tool_use=tool_use, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + act_results = [value async for value in generator] + exp_results = [ + ToolStreamEvent(tool_use, 13), + ToolResultEvent( + { + "status": "error", + "content": [{"text": "Error: It's an error!"}], + "toolUseId": "test-id-2", + } + ), + ] + + assert act_results == exp_results + + +@pytest.mark.asyncio +async def test_tool_async_generator_yield_object_result(): + """Test that async generators handle exceptions.""" + + @strands.tool(context=False) + async def async_generator() -> AsyncGenerator: + """Tool that expects tool_context as a regular string parameter.""" + yield 13 + yield { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + } + + tool: AgentTool = async_generator + tool_use: ToolUse = { + "toolUseId": "test-id-2", + "name": "context_tool", + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, + } + generator = tool.stream( + tool_use=tool_use, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + act_results = [value async for value in generator] + exp_results = [ + ToolStreamEvent(tool_use, 13), + ToolStreamEvent( + tool_use, + { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + }, + ), + ToolResultEvent( + { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": 
"test-id-2", + } + ), + ] + + assert act_results == exp_results From ec000b82e90872335229cf8656df595e871026fe Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 5 Sep 2025 08:47:43 -0400 Subject: [PATCH 073/104] ci: update openai requirement from <1.100.0 to <1.102.0 (#722) * ci: update openai requirement from <1.100.0 to <1.102.0 Updates the requirements on [openai](https://github.com/openai/openai-python) to permit the latest version. - [Release notes](https://github.com/openai/openai-python/releases) - [Changelog](https://github.com/openai/openai-python/blob/main/CHANGELOG.md) - [Commits](https://github.com/openai/openai-python/compare/v1.68.0...v1.101.0) --- updated-dependencies: - dependency-name: openai dependency-version: 1.101.0 dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Nick Clegg --- pyproject.toml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8a95ba04c..a0be0ddc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,9 +68,8 @@ docs = [ "sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] litellm = [ - "litellm>=1.73.1,<2.0.0", - # https://github.com/BerriAI/litellm/issues/13711 - "openai<1.100.0", + "litellm>=1.75.9,<2.0.0", + "openai>=1.68.0,<1.102.0", ] llamaapi = [ "llama-api-client>=0.1.0,<1.0.0", From d77f08b0bbe4736e3e2031d4cbf52e74263887e2 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 5 Sep 2025 15:19:27 -0400 Subject: [PATCH 074/104] fix: only add signature to reasoning blocks if signature is provided (#806) * fix: only add signature to reasoning blocks if signature is provided --------- Co-authored-by: Mackenzie Zastrow Co-authored-by: Dean Schmigelski --- src/strands/event_loop/streaming.py | 1 - tests/strands/event_loop/test_streaming.py | 86 +++++++++++++++++++++- 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index efe094e5f..183fe1ec8 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -289,7 +289,6 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[T "text": "", "current_tool_use": {}, "reasoningText": "", - "signature": "", "citationsContent": [], } state["content"] = state["message"]["content"] diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index ce12b4e98..32d1889e5 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -1,10 +1,12 @@ import unittest.mock +from typing import cast import pytest import strands import strands.event_loop -from strands.types._events import TypedEvent +from strands.types._events import ModelStopReason, TypedEvent +from strands.types.content import Message from strands.types.streaming import ( ContentBlockDeltaEvent, ContentBlockStartEvent, @@ -565,6 +567,88 @@ async def test_process_stream(response, exp_events, agenerator, alist): assert non_typed_events == [] +def _get_message_from_event(event: ModelStopReason) -> Message: + return cast(Message, event["stop"][1]) + + +@pytest.mark.asyncio +async def test_process_stream_with_no_signature(agenerator, alist): + 
response = [ + {"messageStart": {"role": "assistant"}}, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": 'User asks: "Reason about 2+2" so I will do that'}}, + "contentBlockIndex": 0, + } + }, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "."}}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + { + "contentBlockDelta": { + "delta": {"text": "Sure! Let’s do it"}, + "contentBlockIndex": 1, + } + }, + {"contentBlockStop": {"contentBlockIndex": 1}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 112, "outputTokens": 764, "totalTokens": 876}, + "metrics": {"latencyMs": 2970}, + } + }, + ] + + stream = strands.event_loop.streaming.process_stream(agenerator(response)) + + last_event = cast(ModelStopReason, (await alist(stream))[-1]) + + message = _get_message_from_event(last_event) + + assert "signature" not in message["content"][0]["reasoningContent"]["reasoningText"] + assert message["content"][1]["text"] == "Sure! Let’s do it" + + +@pytest.mark.asyncio +async def test_process_stream_with_signature(agenerator, alist): + response = [ + {"messageStart": {"role": "assistant"}}, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": 'User asks: "Reason about 2+2" so I will do that'}}, + "contentBlockIndex": 0, + } + }, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "."}}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "test-"}}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "signature"}}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + { + "contentBlockDelta": { + "delta": {"text": "Sure! Let’s do it"}, + "contentBlockIndex": 1, + } + }, + {"contentBlockStop": {"contentBlockIndex": 1}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 112, "outputTokens": 764, "totalTokens": 876}, + "metrics": {"latencyMs": 2970}, + } + }, + ] + + stream = strands.event_loop.streaming.process_stream(agenerator(response)) + + last_event = cast(ModelStopReason, (await alist(stream))[-1]) + + message = _get_message_from_event(last_event) + + assert message["content"][0]["reasoningContent"]["reasoningText"]["signature"] == "test-signature" + assert message["content"][1]["text"] == "Sure! 
Let’s do it" + + @pytest.mark.asyncio async def test_stream_messages(agenerator, alist): mock_model = unittest.mock.MagicMock() From faeb21aba456a2114acd95a454b58aa51daad670 Mon Sep 17 00:00:00 2001 From: Parham Ghazanfari Date: Mon, 8 Sep 2025 10:48:11 -0400 Subject: [PATCH 075/104] fix: Moved tool_spec retrieval to after the before model invocation callback (#786) Co-authored-by: Parham Ghazanfari --- src/strands/event_loop/event_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 5d5085101..099a524c6 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -132,14 +132,14 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> model_id=model_id, ) with trace_api.use_span(model_invoke_span): - tool_specs = agent.tool_registry.get_all_tool_specs() - agent.hooks.invoke_callbacks( BeforeModelInvocationEvent( agent=agent, ) ) + tool_specs = agent.tool_registry.get_all_tool_specs() + try: async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): if not isinstance(event, ModelStopReason): From b568864561724eae357295b1a8c420ffb3244daa Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 8 Sep 2025 17:56:40 +0300 Subject: [PATCH 076/104] fix(graph): fix cyclic graph behavior (#768) fix a bug in the Graph multiagent pattern where the reset_on_revisit feature fails to enable cycles and feedback loops. The issue was in the _find_newly_ready_nodes method, which filtered out completed nodes before they could be revisited, making it impossible to implement feedback loops even when reset_on_revisit=True. --------- Co-authored-by: Murat Kaan Meral Co-authored-by: Mackenzie Zastrow --- src/strands/multiagent/graph.py | 25 +- tests/strands/multiagent/test_graph.py | 318 ++++++++++++++++++++----- 2 files changed, 266 insertions(+), 77 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 081193b10..d2838396d 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -469,41 +469,32 @@ async def _execute_graph(self) -> None: ready_nodes.clear() # Execute current batch of ready nodes concurrently - tasks = [ - asyncio.create_task(self._execute_node(node)) - for node in current_batch - if node not in self.state.completed_nodes - ] + tasks = [asyncio.create_task(self._execute_node(node)) for node in current_batch] for task in tasks: await task # Find newly ready nodes after batch execution - ready_nodes.extend(self._find_newly_ready_nodes()) + # We add all nodes in current batch as completed batch, + # because a failure would throw exception and code would not make it here + ready_nodes.extend(self._find_newly_ready_nodes(current_batch)) - def _find_newly_ready_nodes(self) -> list["GraphNode"]: + def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]: """Find nodes that became ready after the last execution.""" newly_ready = [] for _node_id, node in self.nodes.items(): - if ( - node not in self.state.completed_nodes - and node not in self.state.failed_nodes - and self._is_node_ready_with_conditions(node) - ): + if self._is_node_ready_with_conditions(node, completed_batch): newly_ready.append(node) return newly_ready - def _is_node_ready_with_conditions(self, node: GraphNode) -> bool: + def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list["GraphNode"]) -> bool: 
"""Check if a node is ready considering conditional edges.""" # Get incoming edges to this node incoming_edges = [edge for edge in self.edges if edge.to_node == node] - if not incoming_edges: - return node in self.entry_points - # Check if at least one incoming edge condition is satisfied for edge in incoming_edges: - if edge.from_node in self.state.completed_nodes: + if edge.from_node in completed_batch: if edge.should_traverse(self.state): logger.debug( "from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 9977c54cd..1a598847d 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1,6 +1,6 @@ import asyncio import time -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, call, patch import pytest @@ -318,7 +318,7 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span): @pytest.mark.asyncio async def test_cyclic_graph_execution(mock_strands_tracer, mock_use_span): - """Test execution of a graph with cycles.""" + """Test execution of a graph with cycles and proper exit conditions.""" # Create mock agents with state tracking agent_a = create_mock_agent("agent_a", "Agent A response") agent_b = create_mock_agent("agent_b", "Agent B response") @@ -332,16 +332,33 @@ async def test_cyclic_graph_execution(mock_strands_tracer, mock_use_span): # Create a spy to track reset calls reset_spy = MagicMock() - # Create a graph with a cycle: A -> B -> C -> A + # Create conditions for controlled cycling + def a_to_b_condition(state: GraphState) -> bool: + # A can trigger B if B hasn't been executed yet + b_count = sum(1 for node in state.execution_order if node.node_id == "b") + return b_count == 0 + + def b_to_c_condition(state: GraphState) -> bool: + # B can always trigger C (unconditional) + return True + + def c_to_a_condition(state: GraphState) -> bool: + # C can trigger A only if A has been executed less than 2 times + a_count = sum(1 for node in state.execution_order if node.node_id == "a") + return a_count < 2 + + # Create a graph with conditional cycle: A -> B -> C -> A (with conditions) builder = GraphBuilder() builder.add_node(agent_a, "a") builder.add_node(agent_b, "b") builder.add_node(agent_c, "c") - builder.add_edge("a", "b") - builder.add_edge("b", "c") - builder.add_edge("c", "a") # Creates cycle + builder.add_edge("a", "b", condition=a_to_b_condition) # A -> B only if B not executed + builder.add_edge("b", "c", condition=b_to_c_condition) # B -> C always + builder.add_edge("c", "a", condition=c_to_a_condition) # C -> A only if A executed < 2 times builder.set_entry_point("a") - builder.reset_on_revisit() # Enable state reset on revisit + builder.reset_on_revisit(True) # Enable state reset on revisit + builder.set_max_node_executions(10) # Safety limit + builder.set_execution_timeout(30.0) # Safety timeout # Patch the reset_executor_state method to track calls original_reset = GraphNode.reset_executor_state @@ -353,51 +370,29 @@ def spy_reset(self): with patch.object(GraphNode, "reset_executor_state", spy_reset): graph = builder.build() - # Set a maximum iteration limit to prevent infinite loops - # but ensure we go through the cycle at least twice - # This value is used in the LimitedGraph class below - - # Execute the graph with a task that will cause it to cycle + # Execute the graph with controlled cycling result = await 
graph.invoke_async("Test cyclic graph execution") # Verify that the graph executed successfully assert result.status == Status.COMPLETED - # Verify that each agent was called at least once - agent_a.invoke_async.assert_called() - agent_b.invoke_async.assert_called() - agent_c.invoke_async.assert_called() - - # Verify that the execution order includes all nodes - assert len(result.execution_order) >= 3 - assert any(node.node_id == "a" for node in result.execution_order) - assert any(node.node_id == "b" for node in result.execution_order) - assert any(node.node_id == "c" for node in result.execution_order) - - # Verify that node state was reset during cyclic execution - # If we have more than 3 nodes in execution_order, at least one node was revisited - if len(result.execution_order) > 3: - # Check that reset_executor_state was called for revisited nodes - reset_spy.assert_called() - - # Count occurrences of each node in execution order - node_counts = {} - for node in result.execution_order: - node_counts[node.node_id] = node_counts.get(node.node_id, 0) + 1 - - # At least one node should appear multiple times - assert any(count > 1 for count in node_counts.values()), "No node was revisited in the cycle" - - # For each node that appears multiple times, verify reset was called - for node_id, count in node_counts.items(): - if count > 1: - # Check that reset was called at least (count-1) times for this node - reset_calls = sum(1 for call in reset_spy.call_args_list if call[0][0] == node_id) - assert reset_calls >= count - 1, ( - f"Node {node_id} appeared {count} times but reset was called {reset_calls} times" - ) - - # Verify all nodes were completed + # Expected execution order: a -> b -> c -> a (4 total executions) + # A executes twice (initial + after c), B executes once, C executes once + assert len(result.execution_order) == 4 + + # Verify execution order + execution_ids = [node.node_id for node in result.execution_order] + assert execution_ids == ["a", "b", "c", "a"] + + # Verify that each agent was called the expected number of times + assert agent_a.invoke_async.call_count == 2 # A executes twice + assert agent_b.invoke_async.call_count == 1 # B executes once + assert agent_c.invoke_async.call_count == 1 # C executes once + + # Verify that node state was reset for the revisited node (A) + assert reset_spy.call_args_list == [call("a")] # Only A should be reset (when revisited) + + # Verify all nodes were completed (final state) assert result.completed_nodes == 3 @@ -423,8 +418,6 @@ def test_graph_builder_validation(): builder.add_node(same_agent, "node2") # Same agent instance, different node_id # Test duplicate node instances in Graph.__init__ - from strands.multiagent.graph import Graph, GraphNode - duplicate_agent = create_mock_agent("duplicate_agent") node1 = GraphNode("node1", duplicate_agent) node2 = GraphNode("node2", duplicate_agent) # Same agent instance @@ -566,7 +559,9 @@ async def test_graph_execution_limits(mock_strands_tracer, mock_use_span): assert result.status == Status.FAILED # Should fail due to limit assert len(result.execution_order) == 2 # Should stop at 2 executions - # Test execution timeout by manipulating start time (like Swarm does) + +@pytest.mark.asyncio +async def test_graph_execution_limits_with_cyclic_graph(mock_strands_tracer, mock_use_span): timeout_agent_a = create_mock_agent("timeout_agent_a", "Response A") timeout_agent_b = create_mock_agent("timeout_agent_b", "Response B") @@ -581,16 +576,28 @@ async def 
test_graph_execution_limits(mock_strands_tracer, mock_use_span): # Enable reset_on_revisit so the cycle can continue graph = builder.reset_on_revisit(True).set_execution_timeout(5.0).set_max_node_executions(100).build() - # Manipulate the start time to simulate timeout (like Swarm does) - result = await graph.invoke_async("Test execution timeout") - # Manually set start time to simulate timeout condition - graph.state.start_time = time.time() - 10 # Set start time to 10 seconds ago + # Execute the cyclic graph - should hit one of the limits + result = await graph.invoke_async("Test execution limits") - # Check the timeout logic directly - should_continue, reason = graph.state.should_continue(max_node_executions=100, execution_timeout=5.0) + # Should fail due to hitting a limit (either timeout or max executions) + assert result.status == Status.FAILED + # Should have executed many nodes (hitting the limit) + assert len(result.execution_order) >= 50 # Should execute many times before hitting limit + + # Test timeout logic directly (without execution) + test_state = GraphState() + test_state.start_time = time.time() - 10 # Set start time to 10 seconds ago + should_continue, reason = test_state.should_continue(max_node_executions=100, execution_timeout=5.0) assert should_continue is False assert "Execution timed out" in reason + # Test max executions logic directly (without execution) + test_state2 = GraphState() + test_state2.execution_order = [None] * 101 # Simulate 101 executions + should_continue2, reason2 = test_state2.should_continue(max_node_executions=100, execution_timeout=5.0) + assert should_continue2 is False + assert "Max node executions reached" in reason2 + # builder = GraphBuilder() # builder.add_node(slow_agent, "slow") # graph = (builder.set_max_node_executions(1000) # High limit to avoid hitting this @@ -1062,9 +1069,7 @@ async def test_state_reset_only_with_cycles_enabled(): graph = builder.build() # Mock the _execute_node method to test conditional reset logic - import unittest.mock - - with unittest.mock.patch.object(node, "reset_executor_state") as mock_reset: + with patch.object(node, "reset_executor_state") as mock_reset: # Simulate the conditional logic from _execute_node if graph.reset_on_revisit and node in state.completed_nodes: node.reset_executor_state() @@ -1079,7 +1084,7 @@ async def test_state_reset_only_with_cycles_enabled(): builder.reset_on_revisit() graph = builder.build() - with unittest.mock.patch.object(node, "reset_executor_state") as mock_reset: + with patch.object(node, "reset_executor_state") as mock_reset: # Simulate the conditional logic from _execute_node if graph.reset_on_revisit and node in state.completed_nodes: node.reset_executor_state() @@ -1087,3 +1092,196 @@ async def test_state_reset_only_with_cycles_enabled(): # With reset_on_revisit enabled, reset should be called mock_reset.assert_called_once() + + +@pytest.mark.asyncio +async def test_self_loop_functionality(mock_strands_tracer, mock_use_span): + """Test comprehensive self-loop functionality including conditions and reset behavior.""" + # Test basic self-loop with execution counting + self_loop_agent = create_mock_agent("self_loop_agent", "Self loop response") + self_loop_agent.invoke_async = Mock(side_effect=self_loop_agent.invoke_async) + + def loop_condition(state: GraphState) -> bool: + return len(state.execution_order) < 3 + + builder = GraphBuilder() + builder.add_node(self_loop_agent, "self_loop") + builder.add_edge("self_loop", "self_loop", condition=loop_condition) + 
builder.set_entry_point("self_loop") + builder.reset_on_revisit(True) + builder.set_max_node_executions(10) + builder.set_execution_timeout(30.0) + + graph = builder.build() + result = await graph.invoke_async("Test self loop") + + # Verify basic self-loop functionality + assert result.status == Status.COMPLETED + assert self_loop_agent.invoke_async.call_count == 3 + assert len(result.execution_order) == 3 + assert all(node.node_id == "self_loop" for node in result.execution_order) + + +@pytest.mark.asyncio +async def test_self_loop_functionality_without_reset(mock_strands_tracer, mock_use_span): + loop_agent_no_reset = create_mock_agent("loop_agent", "Loop without reset") + + can_only_be_called_twice: Mock = Mock(side_effect=lambda state: can_only_be_called_twice.call_count <= 2) + + builder = GraphBuilder() + builder.add_node(loop_agent_no_reset, "loop_node") + builder.add_edge("loop_node", "loop_node", condition=can_only_be_called_twice) + builder.set_entry_point("loop_node") + builder.reset_on_revisit(False) # Disable state reset + builder.set_max_node_executions(10) + + graph = builder.build() + result = await graph.invoke_async("Test self loop without reset") + + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 2 + + mock_strands_tracer.start_multiagent_span.assert_called() + mock_use_span.assert_called() + + +@pytest.mark.asyncio +async def test_complex_self_loop(mock_strands_tracer, mock_use_span): + """Test complex self-loop scenarios including multi-node graphs and multiple self-loops.""" + start_agent = create_mock_agent("start_agent", "Start") + loop_agent = create_mock_agent("loop_agent", "Loop") + end_agent = create_mock_agent("end_agent", "End") + + def loop_condition(state: GraphState) -> bool: + loop_count = sum(1 for node in state.execution_order if node.node_id == "loop_node") + return loop_count < 2 + + def end_condition(state: GraphState) -> bool: + loop_count = sum(1 for node in state.execution_order if node.node_id == "loop_node") + return loop_count >= 2 + + builder = GraphBuilder() + builder.add_node(start_agent, "start_node") + builder.add_node(loop_agent, "loop_node") + builder.add_node(end_agent, "end_node") + builder.add_edge("start_node", "loop_node") + builder.add_edge("loop_node", "loop_node", condition=loop_condition) + builder.add_edge("loop_node", "end_node", condition=end_condition) + builder.set_entry_point("start_node") + builder.reset_on_revisit(True) + builder.set_max_node_executions(10) + + graph = builder.build() + result = await graph.invoke_async("Test complex graph with self loops") + + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 4 # start -> loop -> loop -> end + assert [node.node_id for node in result.execution_order] == ["start_node", "loop_node", "loop_node", "end_node"] + assert start_agent.invoke_async.call_count == 1 + assert loop_agent.invoke_async.call_count == 2 + assert end_agent.invoke_async.call_count == 1 + + +@pytest.mark.asyncio +async def test_multiple_nodes_with_self_loops(mock_strands_tracer, mock_use_span): + agent_a = create_mock_agent("agent_a", "Agent A") + agent_b = create_mock_agent("agent_b", "Agent B") + + def condition_a(state: GraphState) -> bool: + return sum(1 for node in state.execution_order if node.node_id == "a") < 2 + + def condition_b(state: GraphState) -> bool: + return sum(1 for node in state.execution_order if node.node_id == "b") < 2 + + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + 
builder.add_edge("a", "a", condition=condition_a) + builder.add_edge("b", "b", condition=condition_b) + builder.add_edge("a", "b") + builder.set_entry_point("a") + builder.reset_on_revisit(True) + builder.set_max_node_executions(15) + + graph = builder.build() + result = await graph.invoke_async("Test multiple self loops") + + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 4 # a -> a -> b -> b + assert agent_a.invoke_async.call_count == 2 + assert agent_b.invoke_async.call_count == 2 + + mock_strands_tracer.start_multiagent_span.assert_called() + mock_use_span.assert_called() + + +@pytest.mark.asyncio +async def test_self_loop_state_reset(): + """Test self-loop edge cases including state reset, failure handling, and infinite loop prevention.""" + agent = create_mock_agent("stateful_agent", "Stateful response") + agent.state = AgentState() + + def loop_condition(state: GraphState) -> bool: + return len(state.execution_order) < 3 + + builder = GraphBuilder() + node = builder.add_node(agent, "stateful_node") + builder.add_edge("stateful_node", "stateful_node", condition=loop_condition) + builder.set_entry_point("stateful_node") + builder.reset_on_revisit(True) + builder.set_max_node_executions(10) + + node.reset_executor_state = Mock(wraps=node.reset_executor_state) + + graph = builder.build() + result = await graph.invoke_async("Test state reset") + + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 3 + assert node.reset_executor_state.call_count >= 2 # Reset called for revisits + + +@pytest.mark.asyncio +async def test_infinite_loop_prevention(): + infinite_agent = create_mock_agent("infinite_agent", "Infinite loop") + + def always_true_condition(state: GraphState) -> bool: + return True + + builder = GraphBuilder() + builder.add_node(infinite_agent, "infinite_node") + builder.add_edge("infinite_node", "infinite_node", condition=always_true_condition) + builder.set_entry_point("infinite_node") + builder.reset_on_revisit(True) + builder.set_max_node_executions(5) + + graph = builder.build() + result = await graph.invoke_async("Test infinite loop prevention") + + assert result.status == Status.FAILED + assert len(result.execution_order) == 5 + + +@pytest.mark.asyncio +async def test_infinite_loop_prevention_self_loops(): + multi_agent = create_mock_multi_agent("multi_agent", "Multi-agent response") + loop_count = 0 + + def multi_loop_condition(state: GraphState) -> bool: + nonlocal loop_count + loop_count += 1 + return loop_count <= 2 + + builder = GraphBuilder() + builder.add_node(multi_agent, "multi_node") + builder.add_edge("multi_node", "multi_node", condition=multi_loop_condition) + builder.set_entry_point("multi_node") + builder.reset_on_revisit(True) + builder.set_max_node_executions(10) + + graph = builder.build() + result = await graph.invoke_async("Test multi-agent self loop") + + assert result.status == Status.COMPLETED + assert len(result.execution_order) >= 2 + assert multi_agent.invoke_async.call_count >= 2 From 8cb53d3531149ba4998e68f99aa712550573c620 Mon Sep 17 00:00:00 2001 From: Aryan Orpe Date: Mon, 8 Sep 2025 22:10:13 +0400 Subject: [PATCH 077/104] fix(models): filter reasoningContent in Bedrock requests using DeepSeek (#652) * Fix: strip reasoningContent from messages before sending to Bedrock to avoid ValidationException * Using Message class instead of dict in _strip_reasoning_content_from_message(). 
* fix(models): filter reasoningContent blocks on Bedrock requests using DeepSeek * fix: formatting and linting * fix: formatting and linting * remove unrelated registry formatting * linting * add log --------- Co-authored-by: Dean Schmigelski --- src/strands/models/bedrock.py | 35 ++++++++++++++---- tests/strands/models/test_bedrock.py | 53 ++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 6 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 8a6d5116f..aa19b114d 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -293,7 +293,9 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html """ cleaned_messages = [] + filtered_unknown_members = False + dropped_deepseek_reasoning_content = False for message in messages: cleaned_content: list[ContentBlock] = [] @@ -304,6 +306,12 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: filtered_unknown_members = True continue + # DeepSeek models have issues with reasoningContent + # TODO: Replace with systematic model configuration registry (https://github.com/strands-agents/sdk-python/issues/780) + if "deepseek" in self.config["model_id"].lower() and "reasoningContent" in content_block: + dropped_deepseek_reasoning_content = True + continue + if "toolResult" in content_block: # Create a new content block with only the cleaned toolResult tool_result: ToolResult = content_block["toolResult"] @@ -327,14 +335,19 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: # Keep other content blocks as-is cleaned_content.append(content_block) - # Create new message with cleaned content - cleaned_message: Message = Message(content=cleaned_content, role=message["role"]) - cleaned_messages.append(cleaned_message) + # Create new message with cleaned content (skip if empty for DeepSeek) + if cleaned_content: + cleaned_message: Message = Message(content=cleaned_content, role=message["role"]) + cleaned_messages.append(cleaned_message) if filtered_unknown_members: logger.warning( "Filtered out SDK_UNKNOWN_MEMBER content blocks from messages, consider upgrading boto3 version" ) + if dropped_deepseek_reasoning_content: + logger.debug( + "Filtered DeepSeek reasoningContent content blocks from messages - https://api-docs.deepseek.com/guides/reasoning_model#multi-round-conversation" + ) return cleaned_messages @@ -386,7 +399,8 @@ def _generate_redaction_events(self) -> list[StreamEvent]: { "redactContent": { "redactAssistantContentMessage": self.config.get( - "guardrail_redact_output_message", "[Assistant output redacted.]" + "guardrail_redact_output_message", + "[Assistant output redacted.]", ) } } @@ -699,7 +713,11 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any + self, + output_model: Type[T], + prompt: Messages, + system_prompt: Optional[str] = None, + **kwargs: Any, ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. 
@@ -714,7 +732,12 @@ async def structured_output( """ tool_spec = convert_pydantic_to_tool_spec(output_model) - response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs) + response = self.stream( + messages=prompt, + tool_specs=[tool_spec], + system_prompt=system_prompt, + **kwargs, + ) async for event in streaming.process_stream(response): yield event diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 13918b6ea..f2e459bde 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1392,3 +1392,56 @@ def test_format_request_filters_sdk_unknown_member_content_blocks(model, model_i for block in content: assert "SDK_UNKNOWN_MEMBER" not in block + + +@pytest.mark.asyncio +async def test_stream_deepseek_filters_reasoning_content(bedrock_client, alist): + """Test that DeepSeek models filter reasoningContent from messages during streaming.""" + model = BedrockModel(model_id="us.deepseek.r1-v1:0") + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + { + "role": "assistant", + "content": [ + {"text": "Response"}, + {"reasoningContent": {"reasoningText": {"text": "Thinking..."}}}, + ], + }, + ] + + bedrock_client.converse_stream.return_value = {"stream": []} + + await alist(model.stream(messages)) + + # Verify the request was made with filtered messages (no reasoningContent) + call_args = bedrock_client.converse_stream.call_args[1] + sent_messages = call_args["messages"] + + assert len(sent_messages) == 2 + assert sent_messages[0]["content"] == [{"text": "Hello"}] + assert sent_messages[1]["content"] == [{"text": "Response"}] + + +@pytest.mark.asyncio +async def test_stream_deepseek_skips_empty_messages(bedrock_client, alist): + """Test that DeepSeek models skip messages that would be empty after filtering reasoningContent.""" + model = BedrockModel(model_id="us.deepseek.r1-v1:0") + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"reasoningContent": {"reasoningText": {"text": "Only reasoning..."}}}]}, + {"role": "user", "content": [{"text": "Follow up"}]}, + ] + + bedrock_client.converse_stream.return_value = {"stream": []} + + await alist(model.stream(messages)) + + # Verify the request was made with only non-empty messages + call_args = bedrock_client.converse_stream.call_args[1] + sent_messages = call_args["messages"] + + assert len(sent_messages) == 2 + assert sent_messages[0]["content"] == [{"text": "Hello"}] + assert sent_messages[1]["content"] == [{"text": "Follow up"}] From c142e7ad2453fe5c305e30a2ac30759c7f4b527c Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Mon, 8 Sep 2025 17:05:28 -0400 Subject: [PATCH 078/104] docs: cleanup docs so the yields section renders correctly (#820) --- src/strands/agent/agent.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 1e64f5adb..05e15a5b1 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -547,12 +547,12 @@ async def stream_async( Yields: An async iterator that yields events. 
Each event is a dictionary containing - information about the current state of processing, such as: + information about the current state of processing, such as: - - data: Text content being generated - - complete: Whether this is the final chunk - - current_tool_use: Information about tools being executed - - And other event data provided by the callback handler + - data: Text content being generated + - complete: Whether this is the final chunk + - current_tool_use: Information about tools being executed + - And other event data provided by the callback handler Raises: Exception: Any exceptions from the agent invocation will be propagated to the caller. From f185c52155fda1d54e03ed29f6c29b8d8b0125a2 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Mon, 8 Sep 2025 17:14:43 -0400 Subject: [PATCH 079/104] feat: Warn on unknown model configuration properties (#819) Implement the ability for all built-in providers to emit a warning when unknown configuration properties are included. Co-authored-by: Mackenzie Zastrow --- src/strands/models/_config_validation.py | 27 ++++++++++++++++ src/strands/models/anthropic.py | 3 ++ src/strands/models/bedrock.py | 2 ++ src/strands/models/litellm.py | 3 ++ src/strands/models/llamaapi.py | 3 ++ src/strands/models/mistral.py | 3 ++ src/strands/models/ollama.py | 3 ++ src/strands/models/openai.py | 3 ++ src/strands/models/sagemaker.py | 4 +++ src/strands/models/writer.py | 3 ++ tests/conftest.py | 11 +++++++ tests/strands/models/test_anthropic.py | 18 +++++++++++ tests/strands/models/test_bedrock.py | 18 +++++++++++ tests/strands/models/test_litellm.py | 18 +++++++++++ tests/strands/models/test_llamaapi.py | 18 +++++++++++ tests/strands/models/test_mistral.py | 18 +++++++++++ tests/strands/models/test_ollama.py | 18 +++++++++++ tests/strands/models/test_openai.py | 18 +++++++++++ tests/strands/models/test_sagemaker.py | 41 ++++++++++++++++++++++++ tests/strands/models/test_writer.py | 18 +++++++++++ 20 files changed, 250 insertions(+) create mode 100644 src/strands/models/_config_validation.py diff --git a/src/strands/models/_config_validation.py b/src/strands/models/_config_validation.py new file mode 100644 index 000000000..085449bb8 --- /dev/null +++ b/src/strands/models/_config_validation.py @@ -0,0 +1,27 @@ +"""Configuration validation utilities for model providers.""" + +import warnings +from typing import Any, Mapping, Type + +from typing_extensions import get_type_hints + + +def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None: + """Validate that config keys match the TypedDict fields. + + Args: + config_dict: Dictionary of configuration parameters + config_class: TypedDict class to validate against + """ + valid_keys = set(get_type_hints(config_class).keys()) + provided_keys = set(config_dict.keys()) + invalid_keys = provided_keys - valid_keys + + if invalid_keys: + warnings.warn( + f"Invalid configuration parameters: {sorted(invalid_keys)}." + f"\nValid parameters are: {sorted(valid_keys)}." 
+ f"\n" + f"\nSee https://github.com/strands-agents/sdk-python/issues/815", + stacklevel=4, + ) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 29cb40d40..06dc816f2 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -19,6 +19,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolSpec +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -67,6 +68,7 @@ def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_conf For a complete list of supported arguments, see https://docs.anthropic.com/en/api/client-sdks. **model_config: Configuration options for the Anthropic model. """ + validate_config_keys(model_config, self.AnthropicConfig) self.config = AnthropicModel.AnthropicConfig(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -81,6 +83,7 @@ def update_config(self, **model_config: Unpack[AnthropicConfig]) -> None: # typ Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.AnthropicConfig) self.config.update(model_config) @override diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index aa19b114d..f18422191 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -24,6 +24,7 @@ ) from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolResult, ToolSpec +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -166,6 +167,7 @@ def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.BedrockConfig) self.config.update(model_config) @override diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index c1e99f1a2..9a31e82df 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -15,6 +15,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent from ..types.tools import ToolSpec +from ._config_validation import validate_config_keys from .openai import OpenAIModel logger = logging.getLogger(__name__) @@ -49,6 +50,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: **model_config: Configuration options for the LiteLLM model. """ self.client_args = client_args or {} + validate_config_keys(model_config, self.LiteLLMConfig) self.config = dict(model_config) logger.debug("config=<%s> | initializing", self.config) @@ -60,6 +62,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.LiteLLMConfig) self.config.update(model_config) @override diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 421b06e52..57ff85c66 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -19,6 +19,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent, Usage from ..types.tools import ToolResult, ToolSpec, ToolUse +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -60,6 +61,7 @@ def __init__( client_args: Arguments for the Llama API client. 
**model_config: Configuration options for the Llama API model. """ + validate_config_keys(model_config, self.LlamaConfig) self.config = LlamaAPIModel.LlamaConfig(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -75,6 +77,7 @@ def update_config(self, **model_config: Unpack[LlamaConfig]) -> None: # type: i Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.LlamaConfig) self.config.update(model_config) @override diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 8855b6d64..401dde98e 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -16,6 +16,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolResult, ToolSpec, ToolUse +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -82,6 +83,7 @@ def __init__( if not 0.0 <= top_p <= 1.0: raise ValueError(f"top_p must be between 0.0 and 1.0, got {top_p}") + validate_config_keys(model_config, self.MistralConfig) self.config = MistralModel.MistralConfig(**model_config) # Set default stream to True if not specified @@ -101,6 +103,7 @@ def update_config(self, **model_config: Unpack[MistralConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.MistralConfig) self.config.update(model_config) @override diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 76cd87d72..4025dc062 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -14,6 +14,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolSpec +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -70,6 +71,7 @@ def __init__( """ self.host = host self.client_args = ollama_client_args or {} + validate_config_keys(model_config, self.OllamaConfig) self.config = OllamaModel.OllamaConfig(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -81,6 +83,7 @@ def update_config(self, **model_config: Unpack[OllamaConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.OllamaConfig) self.config.update(model_config) @override diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 1076fbae4..16eb4defe 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -17,6 +17,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent from ..types.tools import ToolResult, ToolSpec, ToolUse +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -61,6 +62,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: For a complete list of supported arguments, see https://pypi.org/project/openai/. **model_config: Configuration options for the OpenAI model. """ + validate_config_keys(model_config, self.OpenAIConfig) self.config = dict(model_config) logger.debug("config=<%s> | initializing", self.config) @@ -75,6 +77,7 @@ def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: Args: **model_config: Configuration overrides. 
""" + validate_config_keys(model_config, self.OpenAIConfig) self.config.update(model_config) @override diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 9cfe27d9e..74069b895 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -15,6 +15,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent from ..types.tools import ToolResult, ToolSpec +from ._config_validation import validate_config_keys from .openai import OpenAIModel T = TypeVar("T", bound=BaseModel) @@ -146,6 +147,8 @@ def __init__( boto_session: Boto Session to use when calling the SageMaker Runtime. boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client. """ + validate_config_keys(endpoint_config, self.SageMakerAIEndpointConfig) + validate_config_keys(payload_config, self.SageMakerAIPayloadSchema) payload_config.setdefault("stream", True) payload_config.setdefault("tool_results_as_user_messages", False) self.endpoint_config = dict(endpoint_config) @@ -180,6 +183,7 @@ def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) -> Args: **endpoint_config: Configuration overrides. """ + validate_config_keys(endpoint_config, self.SageMakerAIEndpointConfig) self.endpoint_config.update(endpoint_config) @override diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index f6a3da3d8..9bcdaad42 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -17,6 +17,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolResult, ToolSpec, ToolUse +from ._config_validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -53,6 +54,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: client_args: Arguments for the Writer client (e.g., api_key, base_url, timeout, etc.). **model_config: Configuration options for the Writer model. """ + validate_config_keys(model_config, self.WriterConfig) self.config = WriterModel.WriterConfig(**model_config) logger.debug("config=<%s> | initializing", self.config) @@ -67,6 +69,7 @@ def update_config(self, **model_config: Unpack[WriterConfig]) -> None: # type: Args: **model_config: Configuration overrides. 
""" + validate_config_keys(model_config, self.WriterConfig) self.config.update(model_config) @override diff --git a/tests/conftest.py b/tests/conftest.py index 3b82e362c..f2a8909cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import logging import os import sys +import warnings import boto3 import moto @@ -107,3 +108,13 @@ def generate(generator): return events, stop.value return generate + + +## Warnings + + +@pytest.fixture +def captured_warnings(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + yield w diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 5e8d69ea7..9a7a4be11 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -767,3 +767,21 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls, tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result + + +def test_config_validation_warns_on_unknown_keys(anthropic_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + AnthropicModel(model_id="test-model", max_tokens=100, invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index f2e459bde..624eec6e9 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1445,3 +1445,21 @@ async def test_stream_deepseek_skips_empty_messages(bedrock_client, alist): assert len(sent_messages) == 2 assert sent_messages[0]["content"] == [{"text": "Hello"}] assert sent_messages[1]["content"] == [{"text": "Follow up"}] + + +def test_config_validation_warns_on_unknown_keys(bedrock_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + BedrockModel(model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 44b6df63b..9140cadcc 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -252,3 +252,21 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result + + +def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings): + """Test that unknown config keys emit a 
warning.""" + LiteLLMModel(client_args={"api_key": "test"}, model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index 309dac2e9..712ef8b7a 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -361,3 +361,21 @@ def test_format_chunk_other(model): with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): model.format_chunk(event) + + +def test_config_validation_warns_on_unknown_keys(llamaapi_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + LlamaAPIModel(model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 2a78024f2..9b3f62a31 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -539,3 +539,21 @@ async def test_structured_output_invalid_json(mistral_client, model, test_output with pytest.raises(ValueError, match="Failed to parse tool call arguments into model"): stream = model.structured_output(test_output_model_cls, prompt) await anext(stream) + + +def test_config_validation_warns_on_unknown_keys(mistral_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + MistralModel(model_id="test-model", max_tokens=100, invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index c3fb7736e..9a63a3214 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -516,3 +516,21 @@ async def test_structured_output(ollama_client, model, test_output_model_cls, al tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result + + +def test_config_validation_warns_on_unknown_keys(ollama_client, captured_warnings): + """Test 
that unknown config keys emit a warning.""" + OllamaModel("http://localhost:11434", model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index a7c97701c..00cae7447 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -583,3 +583,21 @@ async def test_structured_output(openai_client, model, test_output_model_cls, al tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result + + +def test_config_validation_warns_on_unknown_keys(openai_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + OpenAIModel({"api_key": "test"}, model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index ba395b2d6..a9071c7e2 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -572,3 +572,44 @@ def test_tool_call(self): assert tool2.type == "function" assert tool2.function.name == "get_time" assert tool2.function.arguments == '{"timezone": "UTC"}' + + +def test_config_validation_warns_on_unknown_keys_in_endpoint(boto_session, captured_warnings): + """Test that unknown config keys emit a warning.""" + endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1", "invalid_param": "test"} + payload_config = {"max_tokens": 1024} + + SageMakerAIModel( + endpoint_config=endpoint_config, + payload_config=payload_config, + boto_session=boto_session, + ) + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_config_validation_warns_on_unknown_keys_in_payload(boto_session, captured_warnings): + """Test that unknown config keys emit a warning.""" + endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} + payload_config = {"max_tokens": 1024, "invalid_param": "test"} + + SageMakerAIModel( + endpoint_config=endpoint_config, + payload_config=payload_config, + boto_session=boto_session, + ) + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, 
captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py index f7748cfdb..75896ca68 100644 --- a/tests/strands/models/test_writer.py +++ b/tests/strands/models/test_writer.py @@ -380,3 +380,21 @@ async def test_stream_with_empty_choices(writer_client, model, model_id): "stream_options": {"include_usage": True}, } writer_client.chat.chat.assert_called_once_with(**expected_request) + + +def test_config_validation_warns_on_unknown_keys(writer_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + WriterModel({"api_key": "test"}, model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) From 1f27488d5ec1f38db3a10778285efada6ffd3822 Mon Sep 17 00:00:00 2001 From: Hamed Soleimani Date: Mon, 8 Sep 2025 14:15:07 -0700 Subject: [PATCH 080/104] fix: do not block asyncio event loop between retries (#805) --- src/strands/event_loop/event_loop.py | 4 ++-- .../strands/agent/hooks/test_agent_events.py | 10 ++++---- tests/strands/event_loop/test_event_loop.py | 24 ++++++++++--------- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 099a524c6..1d437e944 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -8,8 +8,8 @@ 4. 
Manage recursive execution cycles """ +import asyncio import logging -import time import uuid from typing import TYPE_CHECKING, Any, AsyncGenerator @@ -189,7 +189,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> MAX_ATTEMPTS, attempt + 1, ) - time.sleep(current_delay) + await asyncio.sleep(current_delay) current_delay = min(current_delay * 2, MAX_DELAY) yield EventLoopThrottleEvent(delay=current_delay) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 07f55b724..01bfc5409 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -31,8 +31,10 @@ async def streaming_tool(): @pytest.fixture -def mock_time(): - with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock: +def mock_sleep(): + with unittest.mock.patch.object( + strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock + ) as mock: yield mock @@ -322,7 +324,7 @@ async def test_stream_e2e_success(alist): @pytest.mark.asyncio -async def test_stream_e2e_throttle_and_redact(alist, mock_time): +async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): model = MagicMock() model.stream.side_effect = [ ModelThrottledException("ThrottlingException | ConverseStream"), @@ -389,7 +391,7 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_time): async def test_event_loop_cycle_text_response_throttling_early_end( agenerator, alist, - mock_time, + mock_sleep, ): model = MagicMock() model.stream.side_effect = [ diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 68f9cc5ab..9d9e20863 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -26,8 +26,10 @@ @pytest.fixture -def mock_time(): - with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock: +def mock_sleep(): + with unittest.mock.patch.object( + strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock + ) as mock: yield mock @@ -186,7 +188,7 @@ async def test_event_loop_cycle_text_response( @pytest.mark.asyncio async def test_event_loop_cycle_text_response_throttling( - mock_time, + mock_sleep, agent, model, agenerator, @@ -215,12 +217,12 @@ async def test_event_loop_cycle_text_response_throttling( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state # Verify that sleep was called once with the initial delay - mock_time.sleep.assert_called_once() + mock_sleep.assert_called_once() @pytest.mark.asyncio async def test_event_loop_cycle_exponential_backoff( - mock_time, + mock_sleep, agent, model, agenerator, @@ -254,13 +256,13 @@ async def test_event_loop_cycle_exponential_backoff( # Verify that sleep was called with increasing delays # Initial delay is 4, then 8, then 16 - assert mock_time.sleep.call_count == 3 - assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)] + assert mock_sleep.call_count == 3 + assert mock_sleep.call_args_list == [call(4), call(8), call(16)] @pytest.mark.asyncio async def test_event_loop_cycle_text_response_throttling_exceeded( - mock_time, + mock_sleep, agent, model, alist, @@ -281,7 +283,7 @@ async def test_event_loop_cycle_text_response_throttling_exceeded( ) await alist(stream) - mock_time.sleep.assert_has_calls( + mock_sleep.assert_has_calls( [ call(4), call(8), @@ -687,7 +689,7 @@ async def 
test_event_loop_tracing_with_throttling_exception( ] # Mock the time.sleep function to speed up the test - with patch("strands.event_loop.event_loop.time.sleep"): + with patch("strands.event_loop.event_loop.asyncio.sleep", new_callable=unittest.mock.AsyncMock): stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -816,7 +818,7 @@ async def test_prepare_next_cycle_in_tool_execution(agent, model, tool_stream, a @pytest.mark.asyncio -async def test_event_loop_cycle_exception_model_hooks(mock_time, agent, model, agenerator, alist, hook_provider): +async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, agenerator, alist, hook_provider): """Test that model hooks are correctly emitted even when throttled.""" # Set up the model to raise throttling exceptions multiple times before succeeding exception = ModelThrottledException("ThrottlingException | ConverseStream") From 54206796c609d923f59dbeccfaa8213a74c9a57e Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Mon, 8 Sep 2025 17:32:54 -0400 Subject: [PATCH 081/104] feat: improve structured output tool circular reference handling (#817) * feat: improve structured output tool circular reference handling and optional field detection - Move circular reference detection earlier in schema flattening process - Simplify optional field detection using field.is_required() instead of Union type inspection - Add comprehensive test coverage for circular reference scenarios - Fix handling of fields with default values that make them optional --- src/strands/tools/structured_output.py | 23 +----- tests/strands/tools/test_structured_output.py | 82 ++++++++++++++++++- 2 files changed, 84 insertions(+), 21 deletions(-) diff --git a/src/strands/tools/structured_output.py b/src/strands/tools/structured_output.py index 6f2739d88..2c5922925 100644 --- a/src/strands/tools/structured_output.py +++ b/src/strands/tools/structured_output.py @@ -27,16 +27,16 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: "properties": {}, } - # Add title if present if "title" in schema: flattened["title"] = schema["title"] - # Add description from schema if present, or use model docstring if "description" in schema and schema["description"]: flattened["description"] = schema["description"] # Process properties required_props: list[str] = [] + if "properties" not in schema and "$ref" in schema: + raise ValueError("Circular reference detected and not supported.") if "properties" in schema: required_props = [] for prop_name, prop_value in schema["properties"].items(): @@ -76,9 +76,6 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: if len(required_props) > 0: flattened["required"] = required_props - else: - raise ValueError("Circular reference detected and not supported") - return flattened @@ -325,21 +322,7 @@ def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> continue field_type = field.annotation - - # Handle Optional types - is_optional = False - if ( - field_type is not None - and hasattr(field_type, "__origin__") - and field_type.__origin__ is Union - and hasattr(field_type, "__args__") - ): - # Look for Optional[BaseModel] - for arg in field_type.__args__: - if arg is type(None): - is_optional = True - elif isinstance(arg, type) and issubclass(arg, BaseModel): - field_type = arg + is_optional = not field.is_required() # If this is a BaseModel field, expand its properties with full details if isinstance(field_type, 
type) and issubclass(field_type, BaseModel): diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index 97b68a34c..fe9b55334 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import List, Literal, Optional import pytest from pydantic import BaseModel, Field @@ -157,6 +157,7 @@ def test_convert_pydantic_to_tool_spec_multiple_same_type(): "user2": { "type": ["object", "null"], "description": "The second user", + "title": "UserWithPlanet", "properties": { "name": {"description": "The name of the user", "title": "Name", "type": "string"}, "age": { @@ -208,6 +209,85 @@ class NodeWithCircularRef(BaseModel): convert_pydantic_to_tool_spec(NodeWithCircularRef) +def test_convert_pydantic_with_circular_required_dependency(): + """Test that the tool handles circular dependencies gracefully.""" + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: "NodeWithCircularRef" + + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_convert_pydantic_with_circular_optional_dependency(): + """Test that the tool handles circular dependencies gracefully.""" + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: Optional["NodeWithCircularRef"] = None + + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_convert_pydantic_with_circular_optional_dependenc_not_using_optional_typing(): + """Test that the tool handles circular dependencies gracefully.""" + + class NodeWithCircularRef(BaseModel): + """A node with a circular reference to itself.""" + + name: str = Field(description="The name of the node") + parent: "NodeWithCircularRef" = None + + with pytest.raises(ValueError, match="Circular reference detected and not supported"): + convert_pydantic_to_tool_spec(NodeWithCircularRef) + + +def test_conversion_works_with_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 + class Family(BaseModel): + ages: List[str] = Field(default_factory=list) + names: List[str] = Field(default_factory=list) + + converted_output = convert_pydantic_to_tool_spec(Family) + expected_output = { + "name": "Family", + "description": "Family structured output tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "ages": { + "items": {"type": "string"}, + "title": "Ages", + "type": ["array", "null"], + }, + "names": { + "items": {"type": "string"}, + "title": "Names", + "type": ["array", "null"], + }, + }, + "title": "Family", + } + }, + } + assert converted_output == expected_output + + +def test_marks_fields_as_optional_for_model_w_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 + class Family(BaseModel): + ages: List[str] = Field(default_factory=list) + names: List[str] = Field(default_factory=list) + + converted_output = convert_pydantic_to_tool_spec(Family) + assert "null" in converted_output["inputSchema"]["json"]["properties"]["ages"]["type"] + + def test_convert_pydantic_with_custom_description(): """Test that custom 
descriptions override model docstrings.""" From 6ab1aca789a524a4a35d3d2623edfe009a8e5160 Mon Sep 17 00:00:00 2001 From: ratish <114130421+Ratish1@users.noreply.github.com> Date: Tue, 9 Sep 2025 19:07:29 +0530 Subject: [PATCH 082/104] fix(tools/loader): load and register all decorated @tool functions from file path (#742) - Collect all DecoratedFunctionTool objects when loading a .py file and return list when multiple exist - Normalize loader results and register each AgentTool separately in registry - Add normalize_loaded_tools helper and test for multiple decorated tools --------- Co-authored-by: ratish Co-authored-by: Mackenzie Zastrow --- src/strands/tools/loader.py | 122 +++++++++++++++++----------- src/strands/tools/mcp/mcp_client.py | 11 ++- src/strands/tools/registry.py | 10 +-- tests/strands/tools/test_loader.py | 75 +++++++++++++++++ 4 files changed, 160 insertions(+), 58 deletions(-) diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 56433324e..5935077db 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -4,8 +4,9 @@ import logging import os import sys +import warnings from pathlib import Path -from typing import cast +from typing import List, cast from ..types.tools import AgentTool from .decorator import DecoratedFunctionTool @@ -18,60 +19,42 @@ class ToolLoader: """Handles loading of tools from different sources.""" @staticmethod - def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: - """Load a Python tool module. - - Args: - tool_path: Path to the Python tool file. - tool_name: Name of the tool. + def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: + """Load a Python tool module and return all discovered function-based tools as a list. - Returns: - Tool instance. - - Raises: - AttributeError: If required attributes are missing from the tool module. - ImportError: If there are issues importing the tool module. - TypeError: If the tool function is not callable. - ValueError: If function in module is not a valid tool. - Exception: For other errors during tool loading. + This method always returns a list of AgentTool (possibly length 1). It is the + canonical API for retrieving multiple tools from a single Python file. """ try: - # Check if tool_path is in the format "package.module:function"; but keep in mind windows whose file path - # could have a colon so also ensure that it's not a file + # Support module:function style (e.g. 
package.module:function) if not os.path.exists(tool_path) and ":" in tool_path: module_path, function_name = tool_path.rsplit(":", 1) logger.debug("tool_name=<%s>, module_path=<%s> | importing tool from path", function_name, module_path) try: - # Import the module module = __import__(module_path, fromlist=["*"]) - - # Get the function - if not hasattr(module, function_name): - raise AttributeError(f"Module {module_path} has no function named {function_name}") - - func = getattr(module, function_name) - - if isinstance(func, DecoratedFunctionTool): - logger.debug( - "tool_name=<%s>, module_path=<%s> | found function-based tool", function_name, module_path - ) - # mypy has problems converting between DecoratedFunctionTool <-> AgentTool - return cast(AgentTool, func) - else: - raise ValueError( - f"Function {function_name} in {module_path} is not a valid tool (missing @tool decorator)" - ) - except ImportError as e: raise ImportError(f"Failed to import module {module_path}: {str(e)}") from e + if not hasattr(module, function_name): + raise AttributeError(f"Module {module_path} has no function named {function_name}") + + func = getattr(module, function_name) + if isinstance(func, DecoratedFunctionTool): + logger.debug( + "tool_name=<%s>, module_path=<%s> | found function-based tool", function_name, module_path + ) + return [cast(AgentTool, func)] + else: + raise ValueError( + f"Function {function_name} in {module_path} is not a valid tool (missing @tool decorator)" + ) + # Normal file-based tool loading abs_path = str(Path(tool_path).resolve()) - logger.debug("tool_path=<%s> | loading python tool from path", abs_path) - # First load the module to get TOOL_SPEC and check for Lambda deployment + # Load the module by spec spec = importlib.util.spec_from_file_location(tool_name, abs_path) if not spec: raise ImportError(f"Could not create spec for {tool_name}") @@ -82,24 +65,26 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: sys.modules[tool_name] = module spec.loader.exec_module(module) - # First, check for function-based tools with @tool decorator + # Collect function-based tools decorated with @tool + function_tools: List[AgentTool] = [] for attr_name in dir(module): attr = getattr(module, attr_name) if isinstance(attr, DecoratedFunctionTool): logger.debug( "tool_name=<%s>, tool_path=<%s> | found function-based tool in path", attr_name, tool_path ) - # mypy has problems converting between DecoratedFunctionTool <-> AgentTool - return cast(AgentTool, attr) + function_tools.append(cast(AgentTool, attr)) + + if function_tools: + return function_tools - # If no function-based tools found, fall back to traditional module-level tool + # Fall back to module-level TOOL_SPEC + function tool_spec = getattr(module, "TOOL_SPEC", None) if not tool_spec: raise AttributeError( f"Tool {tool_name} missing TOOL_SPEC (neither at module level nor as a decorated function)" ) - # Standard local tool loading tool_func_name = tool_name if not hasattr(module, tool_func_name): raise AttributeError(f"Tool {tool_name} missing function {tool_func_name}") @@ -108,22 +93,61 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: if not callable(tool_func): raise TypeError(f"Tool {tool_name} function is not callable") - return PythonAgentTool(tool_name, tool_spec, tool_func) + return [PythonAgentTool(tool_name, tool_spec, tool_func)] except Exception: - logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool", tool_name, sys.path) + logger.exception("tool_name=<%s>, 
sys_path=<%s> | failed to load python tool(s)", tool_name, sys.path) raise + @staticmethod + def load_python_tool(tool_path: str, tool_name: str) -> AgentTool: + """DEPRECATED: Load a Python tool module and return a single AgentTool for backwards compatibility. + + Use `load_python_tools` to retrieve all tools defined in a .py file (returns a list). + This function will emit a `DeprecationWarning` and return the first discovered tool. + """ + warnings.warn( + "ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. " + "Use ToolLoader.load_python_tools(...) which always returns a list of AgentTool.", + DeprecationWarning, + stacklevel=2, + ) + + tools = ToolLoader.load_python_tools(tool_path, tool_name) + if not tools: + raise RuntimeError(f"No tools found in {tool_path} for {tool_name}") + return tools[0] + @classmethod def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool: - """Load a tool based on its file extension. + """DEPRECATED: Load a single tool based on its file extension for backwards compatibility. + + Use `load_tools` to retrieve all tools defined in a file (returns a list). + This function will emit a `DeprecationWarning` and return the first discovered tool. + """ + warnings.warn( + "ToolLoader.load_tool is deprecated and will be removed in Strands SDK 2.0. " + "Use ToolLoader.load_tools(...) which always returns a list of AgentTool.", + DeprecationWarning, + stacklevel=2, + ) + + tools = ToolLoader.load_tools(tool_path, tool_name) + if not tools: + raise RuntimeError(f"No tools found in {tool_path} for {tool_name}") + + return tools[0] + + @classmethod + def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]: + """Load tools from a file based on its file extension. Args: tool_path: Path to the tool file. tool_name: Name of the tool. Returns: - Tool instance. + A single Tool instance. Raises: FileNotFoundError: If the tool file does not exist. @@ -138,7 +162,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool: try: if ext == ".py": - return cls.load_python_tool(abs_path, tool_name) + return cls.load_python_tools(abs_path, tool_name) else: raise ValueError(f"Unsupported tool file type: {ext}") except Exception: diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 7cb03e46f..5d9dd0b0f 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -318,10 +318,12 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes """ self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content)) - mapped_content = [ - mapped_content + # Build a typed list of ToolResultContent. Use a clearer local name to avoid shadowing + # and annotate the result for mypy so it knows the intended element type. 
+ mapped_contents: list[ToolResultContent] = [ + mc for content in call_tool_result.content - if (mapped_content := self._map_mcp_content_to_tool_result_content(content)) is not None + if (mc := self._map_mcp_content_to_tool_result_content(content)) is not None ] status: ToolResultStatus = "error" if call_tool_result.isError else "success" @@ -329,8 +331,9 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes result = MCPToolResult( status=status, toolUseId=tool_use_id, - content=mapped_content, + content=mapped_contents, ) + if call_tool_result.structuredContent: result["structuredContent"] = call_tool_result.structuredContent diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 471472a64..0660337a2 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -127,11 +127,11 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: if not os.path.exists(tool_path): raise FileNotFoundError(f"Tool file not found: {tool_path}") - loaded_tool = ToolLoader.load_tool(tool_path, tool_name) - loaded_tool.mark_dynamic() - - # Because we're explicitly registering the tool we don't need an allowlist - self.register_tool(loaded_tool) + loaded_tools = ToolLoader.load_tools(tool_path, tool_name) + for t in loaded_tools: + t.mark_dynamic() + # Because we're explicitly registering the tool we don't need an allowlist + self.register_tool(t) except Exception as e: exception_str = str(e) logger.exception("tool_name=<%s> | failed to load tool", tool_name) diff --git a/tests/strands/tools/test_loader.py b/tests/strands/tools/test_loader.py index c1b4d7040..6b86d00ee 100644 --- a/tests/strands/tools/test_loader.py +++ b/tests/strands/tools/test_loader.py @@ -235,3 +235,78 @@ def no_spec(): def test_load_tool_no_spec(tool_path): with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"): ToolLoader.load_tool(tool_path, "no_spec") + + with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"): + ToolLoader.load_tools(tool_path, "no_spec") + + with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"): + ToolLoader.load_python_tool(tool_path, "no_spec") + + with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"): + ToolLoader.load_python_tools(tool_path, "no_spec") + + +@pytest.mark.parametrize( + "tool_path", + [ + textwrap.dedent( + """ + import strands + + @strands.tools.tool + def alpha(): + return "alpha" + + @strands.tools.tool + def bravo(): + return "bravo" + """ + ) + ], + indirect=True, +) +def test_load_python_tool_path_multiple_function_based(tool_path): + # load_python_tools, load_tools returns a list when multiple decorated tools are present + loaded_python_tools = ToolLoader.load_python_tools(tool_path, "alpha") + + assert isinstance(loaded_python_tools, list) + assert len(loaded_python_tools) == 2 + assert all(isinstance(t, DecoratedFunctionTool) for t in loaded_python_tools) + names = {t.tool_name for t in loaded_python_tools} + assert names == {"alpha", "bravo"} + + loaded_tools = ToolLoader.load_tools(tool_path, "alpha") + + assert isinstance(loaded_tools, list) + assert len(loaded_tools) == 2 + assert all(isinstance(t, DecoratedFunctionTool) for t in loaded_tools) + names = {t.tool_name for t in loaded_tools} + assert names == {"alpha", "bravo"} + + +@pytest.mark.parametrize( + "tool_path", + [ + textwrap.dedent( + """ + import strands + + @strands.tools.tool + def alpha(): + return "alpha" + + @strands.tools.tool + def bravo(): + 
return "bravo" + """ + ) + ], + indirect=True, +) +def test_load_tool_path_returns_single_tool(tool_path): + # loaded_python_tool and loaded_tool returns single item + loaded_python_tool = ToolLoader.load_python_tool(tool_path, "alpha") + loaded_tool = ToolLoader.load_tool(tool_path, "alpha") + + assert loaded_python_tool.tool_name == "alpha" + assert loaded_tool.tool_name == "alpha" From d66fcdbf8b68432e91bdbab2c087342dc5f5e376 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 9 Sep 2025 09:43:43 -0400 Subject: [PATCH 083/104] fix(models): patch litellm bug to honor passing in use_litellm_proxy as client_args (#808) * fix(models): patch litellm bug to honor passing in use_litellm_proxy as client_args --------- Co-authored-by: Patrick Gray --- src/strands/models/litellm.py | 13 +++++++++++ tests/strands/models/test_litellm.py | 33 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 9a31e82df..36b385281 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -52,6 +52,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: self.client_args = client_args or {} validate_config_keys(model_config, self.LiteLLMConfig) self.config = dict(model_config) + self._apply_proxy_prefix() logger.debug("config=<%s> | initializing", self.config) @@ -64,6 +65,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None: # type: """ validate_config_keys(model_config, self.LiteLLMConfig) self.config.update(model_config) + self._apply_proxy_prefix() @override def get_config(self) -> LiteLLMConfig: @@ -226,3 +228,14 @@ async def structured_output( # If no tool_calls found, raise an error raise ValueError("No tool_calls found in response") + + def _apply_proxy_prefix(self) -> None: + """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True. + + This is a workaround for https://github.com/BerriAI/litellm/issues/13454 + where use_litellm_proxy parameter is not honored. 
+ """ + if self.client_args.get("use_litellm_proxy") and "model_id" in self.config: + model_id = self.get_config()["model_id"] + if not model_id.startswith("litellm_proxy/"): + self.config["model_id"] = f"litellm_proxy/{model_id}" diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 9140cadcc..4f9f48b92 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -58,6 +58,39 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id +@pytest.mark.parametrize( + "client_args, model_id, expected_model_id", + [ + ({"use_litellm_proxy": True}, "openai/gpt-4", "litellm_proxy/openai/gpt-4"), + ({"use_litellm_proxy": False}, "openai/gpt-4", "openai/gpt-4"), + ({"use_litellm_proxy": None}, "openai/gpt-4", "openai/gpt-4"), + ({}, "openai/gpt-4", "openai/gpt-4"), + (None, "openai/gpt-4", "openai/gpt-4"), + ({"use_litellm_proxy": True}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"), + ({"use_litellm_proxy": False}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"), + ], +) +def test__init__use_litellm_proxy_prefix(client_args, model_id, expected_model_id): + """Test litellm_proxy prefix behavior for various configurations.""" + model = LiteLLMModel(client_args=client_args, model_id=model_id) + assert model.get_config()["model_id"] == expected_model_id + + +@pytest.mark.parametrize( + "client_args, initial_model_id, new_model_id, expected_model_id", + [ + ({"use_litellm_proxy": True}, "openai/gpt-4", "anthropic/claude-3", "litellm_proxy/anthropic/claude-3"), + ({"use_litellm_proxy": False}, "openai/gpt-4", "anthropic/claude-3", "anthropic/claude-3"), + (None, "openai/gpt-4", "anthropic/claude-3", "anthropic/claude-3"), + ], +) +def test_update_config_proxy_prefix(client_args, initial_model_id, new_model_id, expected_model_id): + """Test that update_config applies proxy prefix correctly.""" + model = LiteLLMModel(client_args=client_args, model_id=initial_model_id) + model.update_config(model_id=new_model_id) + assert model.get_config()["model_id"] == expected_model_id + + @pytest.mark.parametrize( "content, exp_result", [ From 9213bc580824ff6b9f7dab48568c3376bc6a442d Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Tue, 9 Sep 2025 13:14:47 -0400 Subject: [PATCH 084/104] feat: add default read timeout to Bedrock model (#829) - Set DEFAULT_READ_TIMEOUT constant to 120 seconds - Configure BotocoreConfig with read_timeout when no custom config provided - Add test coverage for default read timeout behavior --- src/strands/models/bedrock.py | 3 ++- tests/strands/models/test_bedrock.py | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index f18422191..8909072f6 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -45,6 +45,7 @@ T = TypeVar("T", bound=BaseModel) +DEFAULT_READ_TIMEOUT = 120 class BedrockModel(Model): """AWS Bedrock model provider implementation. 
@@ -147,7 +148,7 @@ def __init__( client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) else: - client_config = BotocoreConfig(user_agent_extra="strands-agents") + client_config = BotocoreConfig(user_agent_extra="strands-agents", read_timeout=DEFAULT_READ_TIMEOUT) resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 624eec6e9..5e4c20e79 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -11,7 +11,7 @@ import strands from strands.models import BedrockModel -from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION +from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT from strands.types.exceptions import ModelThrottledException from strands.types.tools import ToolSpec @@ -216,6 +216,20 @@ def test__init__default_user_agent(bedrock_client): assert kwargs["service_name"] == "bedrock-runtime" assert isinstance(kwargs["config"], BotocoreConfig) assert kwargs["config"].user_agent_extra == "strands-agents" + assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT + + +def test__init__default_read_timeout(bedrock_client): + """Set default read timeout when no boto_client_config is provided.""" + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + _ = BedrockModel() + + # Verify the client was created with the correct read timeout + mock_session.client.assert_called_once() + args, kwargs = mock_session.client.call_args + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client): From 001aa937cb237a5b3fc8ad216c88c482ec52d074 Mon Sep 17 00:00:00 2001 From: Shang Liu <35161551+liushang1997@users.noreply.github.com> Date: Wed, 10 Sep 2025 08:43:07 -0700 Subject: [PATCH 085/104] feat: add support for Bedrock/Anthropic ToolChoice to structured_output (#720) For structured output so that some providers can force tool calls --------- Co-authored-by: Mackenzie Zastrow Co-authored-by: Shang Liu --- .../{_config_validation.py => _validation.py} | 15 ++ src/strands/models/anthropic.py | 38 ++++- src/strands/models/bedrock.py | 17 +- src/strands/models/litellm.py | 8 +- src/strands/models/llamaapi.py | 9 +- src/strands/models/mistral.py | 9 +- src/strands/models/model.py | 4 +- src/strands/models/ollama.py | 9 +- src/strands/models/openai.py | 40 ++++- src/strands/models/sagemaker.py | 18 +- src/strands/models/writer.py | 9 +- src/strands/types/tools.py | 11 +- tests/strands/models/test_anthropic.py | 81 +++++++++ tests/strands/models/test_bedrock.py | 66 ++++++++ tests/strands/models/test_litellm.py | 35 +++- tests/strands/models/test_llamaapi.py | 35 ++++ tests/strands/models/test_mistral.py | 37 ++++- tests/strands/models/test_ollama.py | 27 ++- tests/strands/models/test_openai.py | 156 ++++++++++++++++++ tests/strands/models/test_sagemaker.py | 26 ++- tests/strands/models/test_writer.py | 39 ++++- tests_integ/models/test_conformance.py | 36 +++- 22 files changed, 678 insertions(+), 47 deletions(-) rename src/strands/models/{_config_validation.py => _validation.py} (66%) diff --git a/src/strands/models/_config_validation.py b/src/strands/models/_validation.py similarity index 66% 
rename from src/strands/models/_config_validation.py rename to src/strands/models/_validation.py index 085449bb8..9eabe28a1 100644 --- a/src/strands/models/_config_validation.py +++ b/src/strands/models/_validation.py @@ -5,6 +5,8 @@ from typing_extensions import get_type_hints +from ..types.tools import ToolChoice + def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None: """Validate that config keys match the TypedDict fields. @@ -25,3 +27,16 @@ def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> f"\nSee https://github.com/strands-agents/sdk-python/issues/815", stacklevel=4, ) + + +def warn_on_tool_choice_not_supported(tool_choice: ToolChoice | None) -> None: + """Emits a warning if a tool choice is provided but not supported by the provider. + + Args: + tool_choice: the tool_choice provided to the provider + """ + if tool_choice: + warnings.warn( + "A ToolChoice was provided to this provider but is not supported and will be ignored", + stacklevel=4, + ) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 06dc816f2..4afc8e3dc 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -18,8 +18,8 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec +from ._validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -195,7 +195,11 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: return formatted_messages def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format an Anthropic streaming request. @@ -203,6 +207,7 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Returns: An Anthropic streaming request. @@ -223,10 +228,25 @@ def format_request( } for tool_spec in tool_specs or [] ], + **(self._format_tool_choice(tool_choice)), **({"system": system_prompt} if system_prompt else {}), **(self.config.get("params") or {}), } + @staticmethod + def _format_tool_choice(tool_choice: ToolChoice | None) -> dict: + if tool_choice is None: + return {} + + if "any" in tool_choice: + return {"tool_choice": {"type": "any"}} + elif "auto" in tool_choice: + return {"tool_choice": {"type": "auto"}} + elif "tool" in tool_choice: + return {"tool_choice": {"type": "tool", "name": cast(ToolChoiceToolDict, tool_choice)["tool"]["name"]}} + else: + return {} + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: """Format the Anthropic response events into standardized message chunks. 
@@ -350,6 +370,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Anthropic model. @@ -358,6 +379,7 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -368,7 +390,7 @@ async def stream( ModelThrottledException: If the request is throttled by Anthropic. """ logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("request=<%s>", request) logger.debug("invoking model") @@ -410,7 +432,13 @@ async def structured_output( """ tool_spec = convert_pydantic_to_tool_spec(output_model) - response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs) + response = self.stream( + messages=prompt, + tool_specs=[tool_spec], + system_prompt=system_prompt, + tool_choice=cast(ToolChoice, {"any": {}}), + **kwargs, + ) async for event in process_stream(response): yield event diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 8909072f6..9efd930d4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -23,8 +23,8 @@ ModelThrottledException, ) from ..types.streaming import CitationsDelta, StreamEvent -from ..types.tools import ToolResult, ToolSpec -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec +from ._validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -196,6 +196,7 @@ def format_request( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format a Bedrock converse stream request. @@ -203,6 +204,7 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Returns: A Bedrock converse stream request. @@ -225,7 +227,7 @@ def format_request( else [] ), ], - "toolChoice": {"auto": {}}, + **({"toolChoice": tool_choice if tool_choice else {"auto": {}}}), } } if tool_specs @@ -417,6 +419,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Bedrock model. @@ -428,6 +431,7 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. 
Yields: @@ -446,7 +450,7 @@ def callback(event: Optional[StreamEvent] = None) -> None: loop = asyncio.get_event_loop() queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue() - thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt) + thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt, tool_choice) task = asyncio.create_task(thread) while True: @@ -464,6 +468,7 @@ def _stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, ) -> None: """Stream conversation with the Bedrock model. @@ -475,6 +480,7 @@ def _stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Raises: ContextWindowOverflowException: If the input exceeds the model's context window. @@ -482,7 +488,7 @@ def _stream( """ try: logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("request=<%s>", request) logger.debug("invoking model") @@ -739,6 +745,7 @@ async def structured_output( messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, + tool_choice=cast(ToolChoice, {"any": {}}), **kwargs, ) async for event in streaming.process_stream(response): diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 36b385281..6bcc1359e 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -14,8 +14,8 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolSpec +from ._validation import validate_config_keys from .openai import OpenAIModel logger = logging.getLogger(__name__) @@ -114,6 +114,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -122,13 +123,14 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. 
""" logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("request=<%s>", request) logger.debug("invoking model") diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 57ff85c66..4e801026c 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -18,8 +18,8 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent, Usage -from ..types.tools import ToolResult, ToolSpec, ToolUse -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -330,6 +330,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LlamaAPI model. @@ -338,6 +339,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -346,6 +349,8 @@ async def stream( Raises: ModelThrottledException: When the model service is throttling requests from the client. """ + warn_on_tool_choice_not_supported(tool_choice) + logger.debug("formatting request") request = self.format_request(messages, tool_specs, system_prompt) logger.debug("request=<%s>", request) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 401dde98e..90cd1b5d8 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -15,8 +15,8 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -397,6 +397,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Mistral model. @@ -405,6 +406,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. 
Yields: @@ -413,6 +416,8 @@ async def stream( Raises: ModelThrottledException: When the model service is throttling requests. """ + warn_on_tool_choice_not_supported(tool_choice) + logger.debug("formatting request") request = self.format_request(messages, tool_specs, system_prompt) logger.debug("request=<%s>", request) diff --git a/src/strands/models/model.py b/src/strands/models/model.py index cb24b704d..7a8b4d4cc 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -8,7 +8,7 @@ from ..types.content import Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolChoice, ToolSpec logger = logging.getLogger(__name__) @@ -70,6 +70,7 @@ def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncIterable[StreamEvent]: """Stream conversation with the model. @@ -84,6 +85,7 @@ def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 4025dc062..c29772215 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -13,8 +13,8 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StopReason, StreamEvent -from ..types.tools import ToolSpec -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolSpec +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -287,6 +287,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Ollama model. @@ -295,11 +296,15 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. 
""" + warn_on_tool_choice_not_supported(tool_choice) + logger.debug("formatting request") request = self.format_request(messages, tool_specs, system_prompt) logger.debug("request=<%s>", request) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 16eb4defe..fd75ea175 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -16,8 +16,8 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -174,6 +174,30 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: "content": [cls.format_request_message_content(content) for content in contents], } + @classmethod + def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str, Any]: + """Format a tool choice for OpenAI compatibility. + + Args: + tool_choice: Tool choice configuration in Bedrock format. + + Returns: + OpenAI compatible tool choice format. + """ + if not tool_choice: + return {} + + match tool_choice: + case {"auto": _}: + return {"tool_choice": "auto"} # OpenAI SDK doesn't define constants for these values + case {"any": _}: + return {"tool_choice": "required"} + case {"tool": {"name": tool_name}}: + return {"tool_choice": {"type": "function", "function": {"name": tool_name}}} + case _: + # This should not happen with proper typing, but handle gracefully + return {"tool_choice": "auto"} + @classmethod def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: """Format an OpenAI compatible messages array. @@ -216,7 +240,11 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format an OpenAI compatible chat streaming request. @@ -224,6 +252,7 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. Returns: An OpenAI compatible chat streaming request. @@ -248,6 +277,7 @@ def format_request( } for tool_spec in tool_specs or [] ], + **(self._format_request_tool_choice(tool_choice)), **cast(dict[str, Any], self.config.get("params", {})), } @@ -329,6 +359,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the OpenAI model. @@ -337,13 +368,14 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. 
**kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. """ logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt) + request = self.format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("formatted request=<%s>", request) logger.debug("invoking model") diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 74069b895..f635acce2 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -14,8 +14,8 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .openai import OpenAIModel T = TypeVar("T", bound=BaseModel) @@ -197,7 +197,11 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i @override def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format an Amazon SageMaker chat streaming request. @@ -205,6 +209,8 @@ def format_request( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** Returns: An Amazon SageMaker chat streaming request. @@ -286,6 +292,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the SageMaker model. @@ -294,16 +301,21 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. 
""" + warn_on_tool_choice_not_supported(tool_choice) + logger.debug("formatting request") request = self.format_request(messages, tool_specs, system_prompt) logger.debug("formatted request=<%s>", request) logger.debug("invoking model") + try: if self.payload_config.get("stream", True): response = self.client.invoke_endpoint_with_response_stream(**request) diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index 9bcdaad42..07119a21a 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -16,8 +16,8 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse -from ._config_validation import validate_config_keys +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -355,6 +355,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Writer model. @@ -363,6 +364,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -371,6 +374,8 @@ async def stream( Raises: ModelThrottledException: When the model service is throttling requests from the client. """ + warn_on_tool_choice_not_supported(tool_choice) + logger.debug("formatting request") request = self.format_request(messages, tool_specs, system_prompt) logger.debug("request=<%s>", request) diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 1e0f4b841..e8d5531b2 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -145,10 +145,15 @@ class ToolContext: invocation_state: dict[str, Any] +# Individual ToolChoice type aliases +ToolChoiceAutoDict = dict[Literal["auto"], ToolChoiceAuto] +ToolChoiceAnyDict = dict[Literal["any"], ToolChoiceAny] +ToolChoiceToolDict = dict[Literal["tool"], ToolChoiceTool] + ToolChoice = Union[ - dict[Literal["auto"], ToolChoiceAuto], - dict[Literal["any"], ToolChoiceAny], - dict[Literal["tool"], ToolChoiceTool], + ToolChoiceAutoDict, + ToolChoiceAnyDict, + ToolChoiceToolDict, ] """ Configuration for how the model should choose tools. 
diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 9a7a4be11..74bbb8d45 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -417,6 +417,72 @@ def test_format_request_with_empty_content(model, model_id, max_tokens): assert tru_request == exp_request +def test_format_request_tool_choice_auto(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tool_choice = {"auto": {}} + + tru_request = model.format_request(messages, tool_specs, tool_choice=tool_choice) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [ + { + "name": "test_tool", + "description": "test tool", + "input_schema": {"key": "value"}, + } + ], + "tool_choice": {"type": "auto"}, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_any(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tool_choice = {"any": {}} + + tru_request = model.format_request(messages, tool_specs, tool_choice=tool_choice) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [ + { + "name": "test_tool", + "description": "test tool", + "input_schema": {"key": "value"}, + } + ], + "tool_choice": {"type": "any"}, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_tool(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tool_choice = {"tool": {"name": "test_tool"}} + + tru_request = model.format_request(messages, tool_specs, tool_choice=tool_choice) + exp_request = { + "max_tokens": max_tokens, + "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], + "model": model_id, + "tools": [ + { + "name": "test_tool", + "description": "test tool", + "input_schema": {"key": "value"}, + } + ], + "tool_choice": {"name": "test_tool", "type": "tool"}, + } + + assert tru_request == exp_request + + def test_format_chunk_message_start(model): event = {"type": "message_start"} @@ -785,3 +851,18 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_tool_choice_supported_no_warning(model, messages, captured_warnings): + """Test that toolChoice doesn't emit warning for supported providers.""" + tool_choice = {"auto": {}} + model.format_request(messages, tool_choice=tool_choice) + + assert len(captured_warnings) == 0 + + +def test_tool_choice_none_no_warning(model, messages, captured_warnings): + """Test that None toolChoice doesn't emit warning.""" + model.format_request(messages, tool_choice=None) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 5e4c20e79..5ff4132d2 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -414,6 +414,57 @@ def test_format_request_tool_specs(model, messages, model_id, tool_spec): assert tru_request == exp_request +def 
test_format_request_tool_choice_auto(model, messages, model_id, tool_spec): + tool_choice = {"auto": {}} + tru_request = model.format_request(messages, [tool_spec], tool_choice=tool_choice) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": tool_choice, + }, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_any(model, messages, model_id, tool_spec): + tool_choice = {"any": {}} + tru_request = model.format_request(messages, [tool_spec], tool_choice=tool_choice) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": tool_choice, + }, + } + + assert tru_request == exp_request + + +def test_format_request_tool_choice_tool(model, messages, model_id, tool_spec): + tool_choice = {"tool": {"name": "test_tool"}} + tru_request = model.format_request(messages, [tool_spec], tool_choice=tool_choice) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"toolSpec": tool_spec}], + "toolChoice": tool_choice, + }, + } + + assert tru_request == exp_request + + def test_format_request_cache(model, messages, model_id, tool_spec, cache_type): model.update_config(cache_prompt=cache_type, cache_tools=cache_type) tru_request = model.format_request(messages, [tool_spec]) @@ -1477,3 +1528,18 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_tool_choice_supported_no_warning(model, messages, tool_spec, captured_warnings): + """Test that toolChoice doesn't emit warning for supported providers.""" + tool_choice = {"auto": {}} + model.format_request(messages, [tool_spec], tool_choice=tool_choice) + + assert len(captured_warnings) == 0 + + +def test_tool_choice_none_no_warning(model, messages, captured_warnings): + """Test that None toolChoice doesn't emit warning.""" + model.format_request(messages, tool_choice=None) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 4f9f48b92..f345ba003 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -1,4 +1,5 @@ import unittest.mock +from unittest.mock import call import pydantic import pytest @@ -219,15 +220,16 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, assert tru_events == exp_events - expected_request = { - "api_key": api_key, - "model": model_id, - "messages": [{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}], - "stream": True, - "stream_options": {"include_usage": True}, - "tools": [], - } - litellm_acompletion.assert_called_once_with(**expected_request) + assert litellm_acompletion.call_args_list == [ + call( + api_key=api_key, + messages=[{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}], + model=model_id, + stream=True, + stream_options={"include_usage": True}, + tools=[], + ) + ] @pytest.mark.asyncio @@ -303,3 +305,18 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in 
str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_tool_choice_supported_no_warning(model, messages, captured_warnings): + """Test that toolChoice doesn't emit warning for supported providers.""" + tool_choice = {"auto": {}} + model.format_request(messages, tool_choice=tool_choice) + + assert len(captured_warnings) == 0 + + +def test_tool_choice_none_no_warning(model, messages, captured_warnings): + """Test that None toolChoice doesn't emit warning.""" + model.format_request(messages, tool_choice=None) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index 712ef8b7a..a6bbf5673 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -379,3 +379,38 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +@pytest.mark.asyncio +async def test_tool_choice_not_supported_warns(model, messages, captured_warnings, alist): + """Test that non-None toolChoice emits warning for unsupported providers.""" + tool_choice = {"auto": {}} + + with unittest.mock.patch.object(model.client.chat.completions, "create") as mock_create: + mock_chunk = unittest.mock.Mock() + mock_chunk.event.event_type = "start" + mock_chunk.event.stop_reason = "stop" + + mock_create.return_value = [mock_chunk] + + response = model.stream(messages, tool_choice=tool_choice) + await alist(response) + + assert len(captured_warnings) == 1 + assert "ToolChoice was provided to this provider but is not supported" in str(captured_warnings[0].message) + + +@pytest.mark.asyncio +async def test_tool_choice_none_no_warning(model, messages, captured_warnings, alist): + """Test that None toolChoice doesn't emit warning.""" + with unittest.mock.patch.object(model.client.chat.completions, "create") as mock_create: + mock_chunk = unittest.mock.Mock() + mock_chunk.event.event_type = "start" + mock_chunk.event.stop_reason = "stop" + + mock_create.return_value = [mock_chunk] + + response = model.stream(messages, tool_choice=None) + await alist(response) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 9b3f62a31..7808336f2 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -437,7 +437,7 @@ def test_format_chunk_unknown(model): @pytest.mark.asyncio -async def test_stream(mistral_client, model, agenerator, alist): +async def test_stream(mistral_client, model, agenerator, alist, captured_warnings): mock_usage = unittest.mock.Mock() mock_usage.prompt_tokens = 100 mock_usage.completion_tokens = 50 @@ -472,6 +472,41 @@ async def test_stream(mistral_client, model, agenerator, alist): mistral_client.chat.stream_async.assert_called_once_with(**expected_request) + assert len(captured_warnings) == 0 + + +@pytest.mark.asyncio +async def test_tool_choice_not_supported_warns(mistral_client, model, agenerator, alist, captured_warnings): + tool_choice = {"auto": {}} + + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + mock_event = unittest.mock.Mock( + data=unittest.mock.Mock( + choices=[ + unittest.mock.Mock( + delta=unittest.mock.Mock(content="test stream", tool_calls=None), + 
finish_reason="end_turn", + ) + ] + ), + usage=mock_usage, + ) + + mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages, None, None, tool_choice=tool_choice) + + # Consume the response + await alist(response) + + assert len(captured_warnings) == 1 + assert "ToolChoice was provided to this provider but is not supported" in str(captured_warnings[0].message) + @pytest.mark.asyncio async def test_stream_rate_limit_error(mistral_client, model, alist): diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index 9a63a3214..14db63a24 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -414,7 +414,7 @@ def test_format_chunk_other(model): @pytest.mark.asyncio -async def test_stream(ollama_client, model, agenerator, alist): +async def test_stream(ollama_client, model, agenerator, alist, captured_warnings): mock_event = unittest.mock.Mock() mock_event.message.tool_calls = None mock_event.message.content = "Hello" @@ -453,6 +453,31 @@ async def test_stream(ollama_client, model, agenerator, alist): } ollama_client.chat.assert_called_once_with(**expected_request) + # Ensure no warnings emitted + assert len(captured_warnings) == 0 + + +@pytest.mark.asyncio +async def test_tool_choice_not_supported_warns(ollama_client, model, agenerator, alist, captured_warnings): + """Test that non-None toolChoice emits warning for unsupported providers.""" + tool_choice = {"auto": {}} + + mock_event = unittest.mock.Mock() + mock_event.message.tool_calls = None + mock_event.message.content = "Hello" + mock_event.done_reason = "stop" + mock_event.eval_count = 10 + mock_event.prompt_eval_count = 5 + mock_event.total_duration = 1000000 # 1ms in nanoseconds + + ollama_client.chat = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + await alist(model.stream(messages, tool_choice=tool_choice)) + + assert len(captured_warnings) == 1 + assert "ToolChoice was provided to this provider but is not supported" in str(captured_warnings[0].message) + @pytest.mark.asyncio async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist): diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 00cae7447..64da3cac2 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -179,6 +179,30 @@ def test_format_request_tool_message(): assert tru_result == exp_result +def test_format_request_tool_choice_auto(): + tool_choice = {"auto": {}} + + tru_result = OpenAIModel._format_request_tool_choice(tool_choice) + exp_result = {"tool_choice": "auto"} + assert tru_result == exp_result + + +def test_format_request_tool_choice_any(): + tool_choice = {"any": {}} + + tru_result = OpenAIModel._format_request_tool_choice(tool_choice) + exp_result = {"tool_choice": "required"} + assert tru_result == exp_result + + +def test_format_request_tool_choice_tool(): + tool_choice = {"tool": {"name": "test_tool"}} + + tru_result = OpenAIModel._format_request_tool_choice(tool_choice) + exp_result = {"tool_choice": {"type": "function", "function": {"name": "test_tool"}}} + assert tru_result == exp_result + + def test_format_request_messages(system_prompt): messages = [ { @@ -278,6 +302,123 @@ def test_format_request(model, messages, tool_specs, system_prompt): assert tru_request == 
exp_request +def test_format_request_with_tool_choice_auto(model, messages, tool_specs, system_prompt): + tool_choice = {"auto": {}} + tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "test", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "tool_choice": "auto", + "max_tokens": 1, + } + assert tru_request == exp_request + + +def test_format_request_with_tool_choice_any(model, messages, tool_specs, system_prompt): + tool_choice = {"any": {}} + tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "test", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "tool_choice": "required", + "max_tokens": 1, + } + assert tru_request == exp_request + + +def test_format_request_with_tool_choice_tool(model, messages, tool_specs, system_prompt): + tool_choice = {"tool": {"name": "test_tool"}} + tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "test", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "tool_choice": {"type": "function", "function": {"name": "test_tool"}}, + "max_tokens": 1, + } + assert tru_request == exp_request + + @pytest.mark.parametrize( ("event", "exp_chunk"), [ @@ -601,3 +742,18 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_tool_choice_supported_no_warning(model, messages, captured_warnings): + """Test that toolChoice doesn't emit warning for supported providers.""" + tool_choice = {"auto": {}} + model.format_request(messages, tool_choice=tool_choice) + + assert len(captured_warnings) == 0 + + +def test_tool_choice_none_no_warning(model, messages, captured_warnings): + """Test that None toolChoice doesn't emit warning.""" + model.format_request(messages, tool_choice=None) + + assert len(captured_warnings) == 0 diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index a9071c7e2..a5662ecdc 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -372,7 +372,7 @@ async def 
test_stream_with_tool_calls(self, sagemaker_client, model, messages): assert tool_use_data["name"] == "get_weather" @pytest.mark.asyncio - async def test_stream_with_partial_json(self, sagemaker_client, model, messages): + async def test_stream_with_partial_json(self, sagemaker_client, model, messages, captured_warnings): """Test streaming response with partial JSON chunks.""" # Mock the response from SageMaker with split JSON mock_response = { @@ -404,6 +404,30 @@ async def test_stream_with_partial_json(self, sagemaker_client, model, messages) text_delta = content_delta["contentBlockDelta"]["delta"]["text"] assert text_delta == "Paris is the capital of France." + # Ensure no warnings emitted + assert len(captured_warnings) == 0 + + @pytest.mark.asyncio + async def test_tool_choice_not_supported_warns(self, sagemaker_client, model, messages, captured_warnings, alist): + """Test that non-None toolChoice emits warning for unsupported providers.""" + tool_choice = {"auto": {}} + + """Test streaming response with partial JSON chunks.""" + # Mock the response from SageMaker with split JSON + mock_response = { + "Body": [ + {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + ] + } + sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response + + await alist(model.stream(messages, tool_choice=tool_choice)) + + # Ensure toolChoice parameter warning + assert len(captured_warnings) == 1 + assert "ToolChoice was provided to this provider but is not supported" in str(captured_warnings[0].message) + @pytest.mark.asyncio async def test_stream_non_streaming(self, sagemaker_client, model, messages): """Test non-streaming response.""" diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py index 75896ca68..8cf64a39a 100644 --- a/tests/strands/models/test_writer.py +++ b/tests/strands/models/test_writer.py @@ -353,7 +353,7 @@ async def test_stream_empty(writer_client, model, model_id): @pytest.mark.asyncio -async def test_stream_with_empty_choices(writer_client, model, model_id): +async def test_stream_with_empty_choices(writer_client, model, model_id, captured_warnings): mock_delta = unittest.mock.Mock(content="content", tool_calls=None) mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) @@ -381,6 +381,43 @@ async def test_stream_with_empty_choices(writer_client, model, model_id): } writer_client.chat.chat.assert_called_once_with(**expected_request) + # Ensure no warnings emitted + assert len(captured_warnings) == 0 + + +@pytest.mark.asyncio +async def test_tool_choice_not_supported_warns(writer_client, model, model_id, captured_warnings, alist): + mock_delta = unittest.mock.Mock(content="content", tool_calls=None) + mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + + mock_event_1 = unittest.mock.Mock(spec=[]) + mock_event_2 = unittest.mock.Mock(choices=[]) + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_5 = unittest.mock.Mock(usage=mock_usage) + + writer_client.chat.chat.return_value = mock_streaming_response( + [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5] + ) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response 
= model.stream(messages, None, None, tool_choice={"auto": {}}) + + # Consume the response + await alist(response) + + expected_request = { + "model": model_id, + "messages": [{"role": "user", "content": [{"text": "test", "type": "text"}]}], + "stream": True, + "stream_options": {"include_usage": True}, + } + writer_client.chat.chat.assert_called_once_with(**expected_request) + + # Ensure expected warning is invoked + assert len(captured_warnings) == 1 + assert "ToolChoice was provided to this provider but is not supported" in str(captured_warnings[0].message) + def test_config_validation_warns_on_unknown_keys(writer_client, captured_warnings): """Test that unknown config keys emit a warning.""" diff --git a/tests_integ/models/test_conformance.py b/tests_integ/models/test_conformance.py index d9875bc07..eaef1eb88 100644 --- a/tests_integ/models/test_conformance.py +++ b/tests_integ/models/test_conformance.py @@ -1,7 +1,11 @@ +from unittest import SkipTest + import pytest +from pydantic import BaseModel +from strands import Agent from strands.models import Model -from tests_integ.models.providers import ProviderInfo, all_providers +from tests_integ.models.providers import ProviderInfo, all_providers, cohere, llama, mistral def get_models(): @@ -20,11 +24,39 @@ def provider_info(request) -> ProviderInfo: return request.param +@pytest.fixture() +def skip_for(provider_info: list[ProviderInfo]): + """A fixture which provides a function to skip the test if the provider is one of the providers specified.""" + + def skip_for_any_provider_in_list(providers: list[ProviderInfo], description: str): + """Skips the current test is the provider is one of those provided.""" + if provider_info in providers: + raise SkipTest(f"Skipping test for {provider_info.id}: {description}") + + return skip_for_any_provider_in_list + + @pytest.fixture() def model(provider_info): return provider_info.create_model() -def test_model_can_be_constructed(model: Model): +def test_model_can_be_constructed(model: Model, skip_for): assert model is not None pass + + +def test_structured_output_is_forced(skip_for, model): + """Tests that structured_output is always forced to return a value even if model doesn't have any information.""" + skip_for([mistral, cohere, llama], "structured_output is not forced for provider ") + + class Weather(BaseModel): + time: str + weather: str + + agent = Agent(model) + + result = agent.structured_output(Weather, "How are you?") + + assert len(result.time) > 0 + assert len(result.weather) > 0 From 7f58ce9f3bade6956841abb82cbfbe29289430f4 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 10 Sep 2025 13:11:03 -0400 Subject: [PATCH 086/104] feat(multiagent): allow callers of swarm and graph to pass kwargs to executors (#816) * feat(multiagent): allow callers of swarm and graph to pass kwargs to executors --------- Co-authored-by: Nick Clegg Co-authored-by: Aditya Bhushan Sharma --- src/strands/multiagent/base.py | 30 ++++++++++++--- src/strands/multiagent/graph.py | 50 ++++++++++++++++++------- src/strands/multiagent/swarm.py | 47 ++++++++++++++++++----- tests/strands/multiagent/test_base.py | 2 +- tests/strands/multiagent/test_graph.py | 52 ++++++++++++++++++++++++++ tests/strands/multiagent/test_swarm.py | 29 ++++++++++++++ 6 files changed, 181 insertions(+), 29 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 69578cb5d..03d7de9b4 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -84,15 +84,35 @@ class 
MultiAgentBase(ABC): """ @abstractmethod - async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult: - """Invoke asynchronously.""" + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> MultiAgentResult: + """Invoke asynchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Additional keyword arguments passed to underlying agents. + """ raise NotImplementedError("invoke_async not implemented") - def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult: - """Invoke synchronously.""" + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> MultiAgentResult: + """Invoke synchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Additional keyword arguments passed to underlying agents. + """ + if invocation_state is None: + invocation_state = {} def execute() -> MultiAgentResult: - return asyncio.run(self.invoke_async(task, **kwargs)) + return asyncio.run(self.invoke_async(task, invocation_state, **kwargs)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index d2838396d..738dc4d4c 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -385,18 +385,42 @@ def __init__( self.state = GraphState() self.tracer = get_tracer() - def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult: - """Invoke the graph synchronously.""" + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> GraphResult: + """Invoke the graph synchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Keyword arguments allowing backward compatible future changes. + """ + if invocation_state is None: + invocation_state = {} def execute() -> GraphResult: - return asyncio.run(self.invoke_async(task)) + return asyncio.run(self.invoke_async(task, invocation_state)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() - async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult: - """Invoke the graph asynchronously.""" + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> GraphResult: + """Invoke the graph asynchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues - a new empty dict + is created if None is provided. + **kwargs: Keyword arguments allowing backward compatible future changes. 
+ """ + if invocation_state is None: + invocation_state = {} + logger.debug("task=<%s> | starting graph execution", task) # Initialize state @@ -420,7 +444,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G self.node_timeout or "None", ) - await self._execute_graph() + await self._execute_graph(invocation_state) # Set final status based on execution results if self.state.failed_nodes: @@ -450,7 +474,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - async def _execute_graph(self) -> None: + async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: """Unified execution flow with conditional routing.""" ready_nodes = list(self.entry_points) @@ -469,7 +493,7 @@ async def _execute_graph(self) -> None: ready_nodes.clear() # Execute current batch of ready nodes concurrently - tasks = [asyncio.create_task(self._execute_node(node)) for node in current_batch] + tasks = [asyncio.create_task(self._execute_node(node, invocation_state)) for node in current_batch] for task in tasks: await task @@ -506,7 +530,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ ) return False - async def _execute_node(self, node: GraphNode) -> None: + async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None: """Execute a single node with error handling and timeout protection.""" # Reset the node's state if reset_on_revisit is enabled and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: @@ -529,11 +553,11 @@ async def _execute_node(self, node: GraphNode) -> None: if isinstance(node.executor, MultiAgentBase): if self.node_timeout is not None: multi_agent_result = await asyncio.wait_for( - node.executor.invoke_async(node_input), + node.executor.invoke_async(node_input, invocation_state), timeout=self.node_timeout, ) else: - multi_agent_result = await node.executor.invoke_async(node_input) + multi_agent_result = await node.executor.invoke_async(node_input, invocation_state) # Create NodeResult with MultiAgentResult directly node_result = NodeResult( @@ -548,11 +572,11 @@ async def _execute_node(self, node: GraphNode) -> None: elif isinstance(node.executor, Agent): if self.node_timeout is not None: agent_response = await asyncio.wait_for( - node.executor.invoke_async(node_input), + node.executor.invoke_async(node_input, **invocation_state), timeout=self.node_timeout, ) else: - agent_response = await node.executor.invoke_async(node_input) + agent_response = await node.executor.invoke_async(node_input, **invocation_state) # Extract metrics from agent response usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index d730d5156..1c2302c28 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -237,18 +237,42 @@ def __init__( self._setup_swarm(nodes) self._inject_swarm_tools() - def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult: - """Invoke the swarm synchronously.""" + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> SwarmResult: + """Invoke the swarm synchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. 
+ **kwargs: Keyword arguments allowing backward compatible future changes. + """ + if invocation_state is None: + invocation_state = {} def execute() -> SwarmResult: - return asyncio.run(self.invoke_async(task)) + return asyncio.run(self.invoke_async(task, invocation_state)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() - async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult: - """Invoke the swarm asynchronously.""" + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> SwarmResult: + """Invoke the swarm asynchronously. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues - a new empty dict + is created if None is provided. + **kwargs: Keyword arguments allowing backward compatible future changes. + """ + if invocation_state is None: + invocation_state = {} + logger.debug("starting swarm execution") # Initialize swarm state with configuration @@ -272,7 +296,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> S self.execution_timeout, ) - await self._execute_swarm() + await self._execute_swarm(invocation_state) except Exception: logger.exception("swarm execution failed") self.state.completion_status = Status.FAILED @@ -483,7 +507,7 @@ def _build_node_input(self, target_node: SwarmNode) -> str: return context_text - async def _execute_swarm(self) -> None: + async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: """Shared execution logic used by execute_async.""" try: # Main execution loop @@ -522,7 +546,7 @@ async def _execute_swarm(self) -> None: # TODO: Implement cancellation token to stop _execute_node from continuing try: await asyncio.wait_for( - self._execute_node(current_node, self.state.task), + self._execute_node(current_node, self.state.task, invocation_state), timeout=self.node_timeout, ) @@ -563,7 +587,9 @@ async def _execute_swarm(self) -> None: f"{elapsed_time:.2f}", ) - async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -> AgentResult: + async def _execute_node( + self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] + ) -> AgentResult: """Execute swarm node.""" start_time = time.time() node_name = node.node_id @@ -583,7 +609,8 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) - # Execute node result = None node.reset_executor_state() - result = await node.executor.invoke_async(node_input) + # Unpacking since this is the agent class. 
Other executors should not unpack + result = await node.executor.invoke_async(node_input, **invocation_state) execution_time = round((time.time() - start_time) * 1000) diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 395d9275c..d21aa6e14 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -155,7 +155,7 @@ def __init__(self): self.received_task = None self.received_kwargs = None - async def invoke_async(self, task, **kwargs): + async def invoke_async(self, task, invocation_state, **kwargs): self.invoke_async_called = True self.received_task = task self.received_kwargs = kwargs diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 1a598847d..8097d944e 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1285,3 +1285,55 @@ def multi_loop_condition(state: GraphState) -> bool: assert result.status == Status.COMPLETED assert len(result.execution_order) >= 2 assert multi_agent.invoke_async.call_count >= 2 + + +@pytest.mark.asyncio +async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying Agent nodes.""" + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) + + builder = GraphBuilder() + builder.add_node(kwargs_agent, "kwargs_node") + graph = builder.build() + + test_invocation_state = {"custom_param": "test_value", "another_param": 42} + result = await graph.invoke_async("Test kwargs passing", test_invocation_state) + + kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_invocation_state) + assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying MultiAgentBase nodes.""" + kwargs_multiagent = create_mock_multi_agent("kwargs_multiagent", "MultiAgent response with kwargs") + kwargs_multiagent.invoke_async = Mock(side_effect=kwargs_multiagent.invoke_async) + + builder = GraphBuilder() + builder.add_node(kwargs_multiagent, "multiagent_node") + graph = builder.build() + + test_invocation_state = {"custom_param": "test_value", "another_param": 42} + result = await graph.invoke_async("Test kwargs passing to multiagent", test_invocation_state) + + kwargs_multiagent.invoke_async.assert_called_once_with( + [{"text": "Test kwargs passing to multiagent"}], test_invocation_state + ) + assert result.status == Status.COMPLETED + + +def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying nodes in sync execution.""" + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) + + builder = GraphBuilder() + builder.add_node(kwargs_agent, "kwargs_node") + graph = builder.build() + + test_invocation_state = {"custom_param": "test_value", "another_param": 42} + result = graph("Test kwargs passing sync", test_invocation_state) + + kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state) + assert result.status == Status.COMPLETED diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 74f89241f..be463c7fd 100644 --- 
a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -469,3 +469,32 @@ def test_swarm_validate_unsupported_features(): with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"): Swarm([agent_with_session]) + + +@pytest.mark.asyncio +async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying agents.""" + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) + + swarm = Swarm(nodes=[kwargs_agent]) + + test_kwargs = {"custom_param": "test_value", "another_param": 42} + result = await swarm.invoke_async("Test kwargs passing", test_kwargs) + + assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs + assert result.status == Status.COMPLETED + + +def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): + """Test that kwargs are passed through to underlying agents in sync execution.""" + kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") + kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) + + swarm = Swarm(nodes=[kwargs_agent]) + + test_kwargs = {"custom_param": "test_value", "another_param": 42} + result = swarm("Test kwargs passing sync", test_kwargs) + + assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs + assert result.status == Status.COMPLETED From 64d61e03cbda95fb2cc00109a78c92330fcf454e Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:18:23 -0400 Subject: [PATCH 087/104] feat: add region-aware default model ID for Bedrock (#835) These changes introduce region-aware default model ID functionality for Bedrock, formatting based on region prefixes, warnings for unsupported regions, and preservation of custom model IDs. Comprehensive test coverage was added, and existing tests were updated. We also maintain compatibility for two key use cases: preserving customer-overridden model IDs and maintaining compatibility with existing DEFAULT_BEDROCK_MODEL_ID usage patterns. --- src/strands/models/bedrock.py | 58 +++++++++++++++-- tests/strands/agent/test_agent.py | 7 +- tests/strands/models/test_bedrock.py | 96 +++++++++++++++++++++++++++- 3 files changed, 152 insertions(+), 9 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9efd930d4..ba1c77193 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -7,6 +7,7 @@ import json import logging import os +import warnings from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast import boto3 @@ -29,7 +30,9 @@ logger = logging.getLogger(__name__) +# See: `BedrockModel._get_default_model_with_warning` for why we need both DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0" +_DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-20250514-v1:0" DEFAULT_BEDROCK_REGION = "us-west-2" BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [ @@ -47,6 +50,7 @@ DEFAULT_READ_TIMEOUT = 120 + class BedrockModel(Model): """AWS Bedrock model provider implementation. 
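Before the `__init__` changes in the next hunk, it helps to see the region resolution this commit introduces in isolation: a prefix is derived from the region name, remapped where the inference endpoint differs from the raw prefix (`ap` becomes `apac`), and substituted into the `_DEFAULT_BEDROCK_MODEL_ID` template, with a warning for regions not known to support the default inference profile. A rough sketch under those rules (the helper name and warning text are illustrative; the real logic lives in `_get_default_model_with_warning` below and also honors a caller-supplied `model_id`):

```python
# Illustrative sketch of the region-to-prefix resolution used for the default Bedrock model id.
import warnings

_TEMPLATE = "{}.anthropic.claude-sonnet-4-20250514-v1:0"
_PREFIX_INFERENCE_MAP = {"ap": "apac"}  # some inference endpoints differ from the raw region prefix


def default_model_for_region(region_name: str) -> str:
    """Derive the default inference-profile model id for a region (sketch only)."""
    # "us-east-1" -> "us", "us-gov-west-1" -> "us-gov", "ap-southeast-1" -> "ap"
    prefix = "-".join(region_name.split("-")[:-2]).lower()
    if prefix not in {"us", "eu", "ap", "us-gov"}:
        warnings.warn(f"{region_name} is not known to support the default inference profile", stacklevel=2)
    return _TEMPLATE.format(_PREFIX_INFERENCE_MAP.get(prefix, prefix))


assert default_model_for_region("us-east-1") == "us.anthropic.claude-sonnet-4-20250514-v1:0"
assert default_model_for_region("eu-west-1").startswith("eu.")
assert default_model_for_region("ap-southeast-1").startswith("apac.")
assert default_model_for_region("us-gov-west-1").startswith("us-gov.")
```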
@@ -129,13 +133,16 @@ def __init__( if region_name and boto_session: raise ValueError("Cannot specify both `region_name` and `boto_session`.") - self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID, include_tool_result_status="auto") + session = boto_session or boto3.Session() + resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION + self.config = BedrockModel.BedrockConfig( + model_id=BedrockModel._get_default_model_with_warning(resolved_region, model_config), + include_tool_result_status="auto", + ) self.update_config(**model_config) logger.debug("config=<%s> | initializing", self.config) - session = boto_session or boto3.Session() - # Add strands-agents to the request user agent if boto_client_config: existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) @@ -150,8 +157,6 @@ def __init__( else: client_config = BotocoreConfig(user_agent_extra="strands-agents", read_timeout=DEFAULT_READ_TIMEOUT) - resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION - self.client = session.client( service_name="bedrock-runtime", config=client_config, @@ -770,3 +775,46 @@ async def structured_output( raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") yield {"output": output_model(**output_response)} + + @staticmethod + def _get_default_model_with_warning(region_name: str, model_config: Optional[BedrockConfig] = None) -> str: + """Get the default Bedrock modelId based on region. + + If the region is not **known** to support inference then we show a helpful warning + that compliments the exception that Bedrock will throw. + If the customer provided a model_id in their config or they overrode the `DEFAULT_BEDROCK_MODEL_ID` + then we should not process further. + + Args: + region_name (str): region for bedrock model + model_config (Optional[dict[str, Any]]): Model Config that caller passes in on init + """ + if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID.format("us"): + return DEFAULT_BEDROCK_MODEL_ID + + model_config = model_config or {} + if model_config.get("model_id"): + return model_config["model_id"] + + prefix_inference_map = {"ap": "apac"} # some inference endpoints can be a bit different than the region prefix + + prefix = "-".join(region_name.split("-")[:-2]).lower() # handles `us-east-1` or `us-gov-east-1` + if prefix not in {"us", "eu", "ap", "us-gov"}: + warnings.warn( + f""" + ================== WARNING ================== + + This region {region_name} does not support + our default inference endpoint: {_DEFAULT_BEDROCK_MODEL_ID.format(prefix)}. + Update the agent to pass in a 'model_id' like so: + ``` + Agent(..., model='valid_model_id', ...) 
+ ```` + Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html + + ================================================== + """, + stacklevel=2, + ) + + return _DEFAULT_BEDROCK_MODEL_ID.format(prefix_inference_map.get(prefix, prefix)) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index a8561abe4..2cd87c26d 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -26,6 +26,9 @@ from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider +# For unit testing we will use the the us inference +FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") + @pytest.fixture def mock_randint(): @@ -211,7 +214,7 @@ def test_agent__init__with_default_model(): agent = Agent() assert isinstance(agent.model, BedrockModel) - assert agent.model.config["model_id"] == DEFAULT_BEDROCK_MODEL_ID + assert agent.model.config["model_id"] == FORMATTED_DEFAULT_MODEL_ID def test_agent__init__with_explicit_model(mock_model): @@ -891,7 +894,7 @@ def test_agent__del__(agent): def test_agent_init_with_no_model_or_model_id(): agent = Agent() assert agent.model is not None - assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID + assert agent.model.get_config().get("model_id") == FORMATTED_DEFAULT_MODEL_ID def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 5ff4132d2..e9bea2686 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -11,10 +11,17 @@ import strands from strands.models import BedrockModel -from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT +from strands.models.bedrock import ( + _DEFAULT_BEDROCK_MODEL_ID, + DEFAULT_BEDROCK_MODEL_ID, + DEFAULT_BEDROCK_REGION, + DEFAULT_READ_TIMEOUT, +) from strands.types.exceptions import ModelThrottledException from strands.types.tools import ToolSpec +FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") + @pytest.fixture def session_cls(): @@ -119,7 +126,7 @@ def test__init__default_model_id(bedrock_client): model = BedrockModel() tru_model_id = model.get_config().get("model_id") - exp_model_id = DEFAULT_BEDROCK_MODEL_ID + exp_model_id = FORMATTED_DEFAULT_MODEL_ID assert tru_model_id == exp_model_id @@ -1543,3 +1550,88 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): model.format_request(messages, tool_choice=None) assert len(captured_warnings) == 0 + + +def test_get_default_model_with_warning_supported_regions_shows_no_warning(captured_warnings): + """Test get_model_prefix_with_warning doesn't warn for supported region prefixes.""" + BedrockModel._get_default_model_with_warning("us-west-2") + BedrockModel._get_default_model_with_warning("eu-west-2") + assert len(captured_warnings) == 0 + + +def test_get_default_model_for_supported_eu_region_returns_correct_model_id(captured_warnings): + model_id = BedrockModel._get_default_model_with_warning("eu-west-1") + assert model_id == "eu.anthropic.claude-sonnet-4-20250514-v1:0" + assert len(captured_warnings) == 0 + + +def test_get_default_model_for_supported_us_region_returns_correct_model_id(captured_warnings): + model_id = BedrockModel._get_default_model_with_warning("us-east-1") + assert model_id == 
"us.anthropic.claude-sonnet-4-20250514-v1:0" + assert len(captured_warnings) == 0 + + +def test_get_default_model_for_supported_gov_region_returns_correct_model_id(captured_warnings): + model_id = BedrockModel._get_default_model_with_warning("us-gov-west-1") + assert model_id == "us-gov.anthropic.claude-sonnet-4-20250514-v1:0" + assert len(captured_warnings) == 0 + + +def test_get_model_prefix_for_ap_region_converts_to_apac_endpoint(captured_warnings): + """Test _get_default_model_with_warning warns for APAC regions since 'ap' is not in supported prefixes.""" + model_id = BedrockModel._get_default_model_with_warning("ap-southeast-1") + assert model_id == "apac.anthropic.claude-sonnet-4-20250514-v1:0" + + +def test_get_default_model_with_warning_unsupported_region_warns(captured_warnings): + """Test _get_default_model_with_warning warns for unsupported regions.""" + BedrockModel._get_default_model_with_warning("ca-central-1") + assert len(captured_warnings) == 1 + assert "This region ca-central-1 does not support" in str(captured_warnings[0].message) + assert "our default inference endpoint" in str(captured_warnings[0].message) + + +def test_get_default_model_with_warning_no_warning_with_custom_model_id(captured_warnings): + """Test _get_default_model_with_warning doesn't warn when custom model_id provided.""" + model_config = {"model_id": "custom-model"} + model_id = BedrockModel._get_default_model_with_warning("ca-central-1", model_config) + + assert model_id == "custom-model" + assert len(captured_warnings) == 0 + + +def test_init_with_unsupported_region_warns(session_cls, captured_warnings): + """Test BedrockModel initialization warns for unsupported regions.""" + BedrockModel(region_name="ca-central-1") + + assert len(captured_warnings) == 1 + assert "This region ca-central-1 does not support" in str(captured_warnings[0].message) + + +def test_init_with_unsupported_region_custom_model_no_warning(session_cls, captured_warnings): + """Test BedrockModel initialization doesn't warn when custom model_id provided.""" + BedrockModel(region_name="ca-central-1", model_id="custom-model") + assert len(captured_warnings) == 0 + + +def test_override_default_model_id_uses_the_overriden_value(captured_warnings): + with unittest.mock.patch("strands.models.bedrock.DEFAULT_BEDROCK_MODEL_ID", "custom-overridden-model"): + model_id = BedrockModel._get_default_model_with_warning("us-east-1") + assert model_id == "custom-overridden-model" + + +def test_no_override_uses_formatted_default_model_id(captured_warnings): + model_id = BedrockModel._get_default_model_with_warning("us-east-1") + assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0" + assert model_id != _DEFAULT_BEDROCK_MODEL_ID + assert len(captured_warnings) == 0 + + +def test_custom_model_id_not_overridden_by_region_formatting(session_cls): + """Test that custom model_id is not overridden by region formatting.""" + custom_model_id = "custom.model.id" + + model = BedrockModel(model_id=custom_model_id) + model_id = model.get_config().get("model_id") + + assert model_id == custom_model_id From ab125f5b35aefffaebe8e331e53ecd711047d97f Mon Sep 17 00:00:00 2001 From: Aaron Brown <47581657+westonbrown@users.noreply.github.com> Date: Wed, 10 Sep 2025 14:26:37 -0500 Subject: [PATCH 088/104] llama.cpp model provider support (#585) --- README.md | 2 + src/strands/models/llamacpp.py | 762 ++++++++++++++++++++++ tests/strands/models/test_llamacpp.py | 639 ++++++++++++++++++ tests_integ/models/test_model_llamacpp.py | 510 +++++++++++++++ 4 files 
changed, 1913 insertions(+) create mode 100644 src/strands/models/llamacpp.py create mode 100644 tests/strands/models/test_llamacpp.py create mode 100644 tests_integ/models/test_model_llamacpp.py diff --git a/README.md b/README.md index 62ed54d47..44d10b67e 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,7 @@ from strands import Agent from strands.models import BedrockModel from strands.models.ollama import OllamaModel from strands.models.llamaapi import LlamaAPIModel +from strands.models.llamacpp import LlamaCppModel # Bedrock bedrock_model = BedrockModel( @@ -159,6 +160,7 @@ Built-in providers: - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/) - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/) - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) + - [llama.cpp](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamacpp/) - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py new file mode 100644 index 000000000..94a225a06 --- /dev/null +++ b/src/strands/models/llamacpp.py @@ -0,0 +1,762 @@ +"""llama.cpp model provider. + +Provides integration with llama.cpp servers running in OpenAI-compatible mode, +with support for advanced llama.cpp-specific features. + +- Docs: https://github.com/ggml-org/llama.cpp +- Server docs: https://github.com/ggml-org/llama.cpp/tree/master/tools/server +- OpenAI API compatibility: + https://github.com/ggml-org/llama.cpp/blob/master/tools/server/README.md#api-endpoints +""" + +import base64 +import json +import logging +import mimetypes +import time +from typing import ( + Any, + AsyncGenerator, + Dict, + Optional, + Type, + TypedDict, + TypeVar, + Union, + cast, +) + +import httpx +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import ContentBlock, Messages +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.streaming import StreamEvent +from ..types.tools import ToolSpec +from .model import Model + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class LlamaCppModel(Model): + """llama.cpp model provider implementation. + + Connects to a llama.cpp server running in OpenAI-compatible mode with + support for advanced llama.cpp-specific features like grammar constraints, + Mirostat sampling, native JSON schema validation, and native multimodal + support for audio and image content. + + The llama.cpp server must be started with the OpenAI-compatible API enabled: + llama-server -m model.gguf --host 0.0.0.0 --port 8080 + + Example: + Basic usage: + >>> model = LlamaCppModel(base_url="http://localhost:8080") + >>> model.update_config(params={"temperature": 0.7, "top_k": 40}) + + Grammar constraints via params: + >>> model.update_config(params={ + ... "grammar": ''' + ... root ::= answer + ... answer ::= "yes" | "no" + ... ''' + ... }) + + Advanced sampling: + >>> model.update_config(params={ + ... "mirostat": 2, + ... "mirostat_lr": 0.1, + ... "tfs_z": 0.95, + ... "repeat_penalty": 1.1 + ... 
}) + + Multimodal usage (requires multimodal model like Qwen2.5-Omni): + >>> # Audio analysis + >>> audio_content = [{ + ... "audio": {"source": {"bytes": audio_bytes}, "format": "wav"}, + ... "text": "What do you hear in this audio?" + ... }] + >>> response = agent(audio_content) + + >>> # Image analysis + >>> image_content = [{ + ... "image": {"source": {"bytes": image_bytes}, "format": "png"}, + ... "text": "Describe this image" + ... }] + >>> response = agent(image_content) + """ + + class LlamaCppConfig(TypedDict, total=False): + """Configuration options for llama.cpp models. + + Attributes: + model_id: Model identifier for the loaded model in llama.cpp server. + Default is "default" as llama.cpp typically loads a single model. + params: Model parameters supporting both OpenAI and llama.cpp-specific options. + + OpenAI-compatible parameters: + - max_tokens: Maximum number of tokens to generate + - temperature: Sampling temperature (0.0 to 2.0) + - top_p: Nucleus sampling parameter (0.0 to 1.0) + - frequency_penalty: Frequency penalty (-2.0 to 2.0) + - presence_penalty: Presence penalty (-2.0 to 2.0) + - stop: List of stop sequences + - seed: Random seed for reproducibility + - n: Number of completions to generate + - logprobs: Include log probabilities in output + - top_logprobs: Number of top log probabilities to include + + llama.cpp-specific parameters: + - repeat_penalty: Penalize repeat tokens (1.0 = no penalty) + - top_k: Top-k sampling (0 = disabled) + - min_p: Min-p sampling threshold (0.0 to 1.0) + - typical_p: Typical-p sampling (0.0 to 1.0) + - tfs_z: Tail-free sampling parameter (0.0 to 1.0) + - top_a: Top-a sampling parameter + - mirostat: Mirostat sampling mode (0, 1, or 2) + - mirostat_lr: Mirostat learning rate + - mirostat_ent: Mirostat target entropy + - grammar: GBNF grammar string for constrained generation + - json_schema: JSON schema for structured output + - penalty_last_n: Number of tokens to consider for penalties + - n_probs: Number of probabilities to return per token + - min_keep: Minimum tokens to keep in sampling + - ignore_eos: Ignore end-of-sequence token + - logit_bias: Token ID to bias mapping + - cache_prompt: Cache the prompt for faster generation + - slot_id: Slot ID for parallel inference + - samplers: Custom sampler order + """ + + model_id: str + params: Optional[dict[str, Any]] + + def __init__( + self, + base_url: str = "http://localhost:8080", + timeout: Optional[Union[float, tuple[float, float]]] = None, + **model_config: Unpack[LlamaCppConfig], + ) -> None: + """Initialize llama.cpp provider instance. + + Args: + base_url: Base URL for the llama.cpp server. + Default is "http://localhost:8080" for local server. + timeout: Request timeout in seconds. Can be float or tuple of + (connect, read) timeouts. + **model_config: Configuration options for the llama.cpp model. 
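+
+        Example:
+            Illustrative construction only; the values below are placeholders:
+            >>> model = LlamaCppModel(
+            ...     base_url="http://localhost:8080",
+            ...     timeout=(5.0, 120.0),  # (connect, read) seconds
+            ...     params={"temperature": 0.7, "top_k": 40},
+            ... )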
+ """ + # Set default model_id if not provided + if "model_id" not in model_config: + model_config["model_id"] = "default" + + self.base_url = base_url.rstrip("/") + self.config = dict(model_config) + + # Configure HTTP client + if isinstance(timeout, tuple): + # Convert tuple to httpx.Timeout object + timeout_obj = httpx.Timeout( + connect=timeout[0] if len(timeout) > 0 else None, + read=timeout[1] if len(timeout) > 1 else None, + write=timeout[2] if len(timeout) > 2 else None, + pool=timeout[3] if len(timeout) > 3 else None, + ) + else: + timeout_obj = httpx.Timeout(timeout or 30.0) + + self.client = httpx.AsyncClient( + base_url=self.base_url, + timeout=timeout_obj, + ) + + logger.debug( + "base_url=<%s>, model_id=<%s> | initializing llama.cpp provider", + base_url, + model_config.get("model_id"), + ) + + @override + def update_config(self, **model_config: Unpack[LlamaCppConfig]) -> None: # type: ignore[override] + """Update the llama.cpp model configuration with provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> LlamaCppConfig: + """Get the llama.cpp model configuration. + + Returns: + The llama.cpp model configuration. + """ + return self.config # type: ignore[return-value] + + def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) -> dict[str, Any]: + """Format a content block for llama.cpp. + + Args: + content: Message content. + + Returns: + llama.cpp compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to a compatible format. + """ + if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "file": { + "file_data": f"data:{mime_type};base64,{file_data}", + "filename": content["document"]["name"], + }, + "type": "file", + } + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + # Handle audio content (not in standard ContentBlock but supported by llama.cpp) + if "audio" in content: + audio_content = cast(Dict[str, Any], content) + audio_data = base64.b64encode(audio_content["audio"]["source"]["bytes"]).decode("utf-8") + audio_format = audio_content["audio"].get("format", "wav") + return { + "type": "input_audio", + "input_audio": {"data": audio_data, "format": audio_format}, + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + def _format_tool_call(self, tool_use: dict[str, Any]) -> dict[str, Any]: + """Format a tool call for llama.cpp. + + Args: + tool_use: Tool use requested by the model. + + Returns: + llama.cpp compatible tool call. + """ + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + def _format_tool_message(self, tool_result: dict[str, Any]) -> dict[str, Any]: + """Format a tool message for llama.cpp. + + Args: + tool_result: Tool result collected from a tool execution. 
+ + Returns: + llama.cpp compatible tool message. + """ + contents = [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ] + + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": [self._format_message_content(content) for content in contents], + } + + def _format_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format messages for llama.cpp. + + Args: + messages: List of message objects to be processed. + system_prompt: System prompt to provide context to the model. + + Returns: + Formatted messages array compatible with llama.cpp. + """ + formatted_messages: list[dict[str, Any]] = [] + + # Add system prompt if provided + if system_prompt: + formatted_messages.append({"role": "system", "content": system_prompt}) + + for message in messages: + contents = message["content"] + + formatted_contents = [ + self._format_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + formatted_tool_calls = [ + self._format_tool_call( + { + "name": content["toolUse"]["name"], + "input": content["toolUse"]["input"], + "toolUseId": content["toolUse"]["toolUseId"], + } + ) + for content in contents + if "toolUse" in content + ] + formatted_tool_messages = [ + self._format_tool_message( + { + "toolUseId": content["toolResult"]["toolUseId"], + "content": content["toolResult"]["content"], + } + ) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents, + **({} if not formatted_tool_calls else {"tool_calls": formatted_tool_calls}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + def _format_request( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + ) -> dict[str, Any]: + """Format a request for the llama.cpp server. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A request formatted for llama.cpp server's OpenAI-compatible API. 
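+
+        Example:
+            Illustrative shape of the returned payload (abridged; assumes default config):
+            >>> model._format_request([{"role": "user", "content": [{"text": "Hi"}]}])
+            {'messages': [...], 'model': 'default', 'stream': True,
+             'stream_options': {'include_usage': True}, 'tools': []}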
+ """ + # Separate OpenAI-compatible and llama.cpp-specific parameters + request = { + "messages": self._format_messages(messages, system_prompt), + "model": self.config["model_id"], + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs or [] + ], + } + + # Handle parameters if provided + params = self.config.get("params") + if params and isinstance(params, dict): + # Grammar and json_schema go directly in request body for llama.cpp server + if "grammar" in params: + request["grammar"] = params["grammar"] + if "json_schema" in params: + request["json_schema"] = params["json_schema"] + + # llama.cpp-specific parameters that must be passed via extra_body + # NOTE: grammar and json_schema are NOT in this set because llama.cpp server + # expects them directly in the request body for proper constraint application + llamacpp_specific_params = { + "repeat_penalty", + "top_k", + "min_p", + "typical_p", + "tfs_z", + "top_a", + "mirostat", + "mirostat_lr", + "mirostat_ent", + "penalty_last_n", + "n_probs", + "min_keep", + "ignore_eos", + "logit_bias", + "cache_prompt", + "slot_id", + "samplers", + } + + # Standard OpenAI parameters that go directly in the request + openai_params = { + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "seed", + "n", + "logprobs", + "top_logprobs", + "response_format", + } + + # Add OpenAI parameters directly to request + for param, value in params.items(): + if param in openai_params: + request[param] = value + + # Collect llama.cpp-specific parameters for extra_body + extra_body: Dict[str, Any] = {} + for param, value in params.items(): + if param in llamacpp_specific_params: + extra_body[param] = value + + # Add extra_body if we have llama.cpp-specific parameters + if extra_body: + request["extra_body"] = extra_body + + return request + + def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format a llama.cpp response event into a standardized message chunk. + + Args: + event: A response event from the llama.cpp server. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. 
+ """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + if event["data_type"] == "reasoning_content": + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": event.get("latency_ms", 0), + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the llama.cpp model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. + + Raises: + ContextWindowOverflowException: When the context window is exceeded. + ModelThrottledException: When the llama.cpp server is overloaded. 
+ """ + # Track request start time for latency calculation + start_time = time.perf_counter() + + try: + logger.debug("formatting request for llama.cpp server") + request = self._format_request(messages, tool_specs, system_prompt) + logger.debug("request=<%s>", request) + + logger.debug("sending request to llama.cpp server") + response = await self.client.post("/v1/chat/completions", json=request) + response.raise_for_status() + + logger.debug("processing streaming response") + yield self._format_chunk({"chunk_type": "message_start"}) + yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) + + tool_calls: Dict[int, list] = {} + usage_data = None + finish_reason = None + + async for line in response.aiter_lines(): + if not line.strip() or not line.startswith("data: "): + continue + + data_content = line[6:] # Remove "data: " prefix + if data_content.strip() == "[DONE]": + break + + try: + event = json.loads(data_content) + except json.JSONDecodeError: + continue + + # Handle usage information + if "usage" in event: + usage_data = event["usage"] + continue + + if not event.get("choices"): + continue + + choice = event["choices"][0] + delta = choice.get("delta", {}) + + # Handle content deltas + if "content" in delta and delta["content"]: + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "text", + "data": delta["content"], + } + ) + + # Handle tool calls + if "tool_calls" in delta: + for tool_call in delta["tool_calls"]: + index = tool_call["index"] + if index not in tool_calls: + tool_calls[index] = [] + tool_calls[index].append(tool_call) + + # Check for finish reason + if choice.get("finish_reason"): + finish_reason = choice.get("finish_reason") + break + + yield self._format_chunk({"chunk_type": "content_stop"}) + + # Process tool calls + for tool_deltas in tool_calls.values(): + first_delta = tool_deltas[0] + yield self._format_chunk( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": type( + "ToolCall", + (), + { + "function": type( + "Function", + (), + { + "name": first_delta.get("function", {}).get("name", ""), + }, + )(), + "id": first_delta.get("id", ""), + }, + )(), + } + ) + + for tool_delta in tool_deltas: + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": type( + "ToolCall", + (), + { + "function": type( + "Function", + (), + { + "arguments": tool_delta.get("function", {}).get("arguments", ""), + }, + )(), + }, + )(), + } + ) + + yield self._format_chunk({"chunk_type": "content_stop"}) + + # Send stop reason + logger.debug("finish_reason=%s, tool_calls=%s", finish_reason, bool(tool_calls)) + if finish_reason == "tool_calls" or tool_calls: + stop_reason = "tool_calls" # Changed from "tool_use" to match format_chunk expectations + else: + stop_reason = finish_reason or "end_turn" + logger.debug("stop_reason=%s", stop_reason) + yield self._format_chunk({"chunk_type": "message_stop", "data": stop_reason}) + + # Send usage metadata if available + if usage_data: + # Calculate latency + latency_ms = int((time.perf_counter() - start_time) * 1000) + yield self._format_chunk( + { + "chunk_type": "metadata", + "data": type( + "Usage", + (), + { + "prompt_tokens": usage_data.get("prompt_tokens", 0), + "completion_tokens": usage_data.get("completion_tokens", 0), + "total_tokens": usage_data.get("total_tokens", 0), + }, + )(), + "latency_ms": latency_ms, + } + ) + + logger.debug("finished streaming response") + + except httpx.HTTPStatusError as e: + if 
e.response.status_code == 400: + # Parse error response from llama.cpp server + try: + error_data = e.response.json() + error_msg = str(error_data.get("error", {}).get("message", str(error_data))) + except (json.JSONDecodeError, KeyError, AttributeError): + error_msg = e.response.text + + # Check for context overflow by looking for specific error indicators + if any(term in error_msg.lower() for term in ["context", "kv cache", "slot"]): + raise ContextWindowOverflowException(f"Context window exceeded: {error_msg}") from e + elif e.response.status_code == 503: + raise ModelThrottledException("llama.cpp server is busy or overloaded") from e + raise + except Exception as e: + # Handle other potential errors like rate limiting + error_msg = str(e).lower() + if "rate" in error_msg or "429" in str(e): + raise ModelThrottledException(str(e)) from e + raise + + @override + async def structured_output( + self, + output_model: Type[T], + prompt: Messages, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + """Get structured output using llama.cpp's native JSON schema support. + + This implementation uses llama.cpp's json_schema parameter to constrain + the model output to valid JSON matching the provided schema. + + Args: + output_model: The Pydantic model defining the expected output structure. + prompt: The prompt messages to use for generation. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. + + Raises: + json.JSONDecodeError: If the model output is not valid JSON. + pydantic.ValidationError: If the output doesn't match the model schema. + """ + # Get the JSON schema from the Pydantic model + schema = output_model.model_json_schema() + + # Store current params to restore later + params = self.config.get("params", {}) + original_params = dict(params) if isinstance(params, dict) else {} + + try: + # Configure for JSON output with schema constraint + params = self.config.get("params", {}) + if not isinstance(params, dict): + params = {} + params["json_schema"] = schema + params["cache_prompt"] = True + self.config["params"] = params + + # Collect the response + response_text = "" + async for event in self.stream(prompt, system_prompt=system_prompt, **kwargs): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + # Forward events to caller + yield cast(Dict[str, Union[T, Any]], event) + + # Parse and validate the JSON response + data = json.loads(response_text.strip()) + output_instance = output_model(**data) + yield {"output": output_instance} + + finally: + # Restore original configuration + self.config["params"] = original_params diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py new file mode 100644 index 000000000..e5b2614c0 --- /dev/null +++ b/tests/strands/models/test_llamacpp.py @@ -0,0 +1,639 @@ +"""Unit tests for llama.cpp model provider.""" + +import base64 +import json +from unittest.mock import AsyncMock, patch + +import httpx +import pytest +from pydantic import BaseModel + +from strands.models.llamacpp import LlamaCppModel +from strands.types.exceptions import ( + ContextWindowOverflowException, + ModelThrottledException, +) + + +def test_init_default_config() -> None: + """Test initialization with default configuration.""" + model = LlamaCppModel() + + 
assert model.config["model_id"] == "default" + assert isinstance(model.client, httpx.AsyncClient) + assert model.base_url == "http://localhost:8080" + + +def test_init_custom_config() -> None: + """Test initialization with custom configuration.""" + model = LlamaCppModel( + base_url="http://example.com:8081", + model_id="llama-3-8b", + params={"temperature": 0.7, "max_tokens": 100}, + ) + + assert model.config["model_id"] == "llama-3-8b" + assert model.config["params"]["temperature"] == 0.7 + assert model.config["params"]["max_tokens"] == 100 + assert model.base_url == "http://example.com:8081" + + +def test_format_request_basic() -> None: + """Test basic request formatting.""" + model = LlamaCppModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model._format_request(messages) + + assert request["model"] == "test-model" + assert request["messages"][0]["role"] == "user" + assert request["messages"][0]["content"][0]["type"] == "text" + assert request["messages"][0]["content"][0]["text"] == "Hello" + assert request["stream"] is True + assert "extra_body" not in request + + +def test_format_request_with_system_prompt() -> None: + """Test request formatting with system prompt.""" + model = LlamaCppModel() + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model._format_request(messages, system_prompt="You are a helpful assistant") + + assert request["messages"][0]["role"] == "system" + assert request["messages"][0]["content"] == "You are a helpful assistant" + assert request["messages"][1]["role"] == "user" + + +def test_format_request_with_llamacpp_params() -> None: + """Test request formatting with llama.cpp specific parameters.""" + model = LlamaCppModel( + params={ + "temperature": 0.8, + "max_tokens": 50, + "repeat_penalty": 1.1, + "top_k": 40, + "min_p": 0.05, + "grammar": "root ::= 'yes' | 'no'", + } + ) + + messages = [ + {"role": "user", "content": [{"text": "Is the sky blue?"}]}, + ] + + request = model._format_request(messages) + + # Standard OpenAI params + assert request["temperature"] == 0.8 + assert request["max_tokens"] == 50 + + # Grammar and json_schema go directly in request for llama.cpp + assert request["grammar"] == "root ::= 'yes' | 'no'" + + # Other llama.cpp specific params should be in extra_body + assert "extra_body" in request + assert request["extra_body"]["repeat_penalty"] == 1.1 + assert request["extra_body"]["top_k"] == 40 + assert request["extra_body"]["min_p"] == 0.05 + + +def test_format_request_with_all_new_params() -> None: + """Test request formatting with all new llama.cpp parameters.""" + model = LlamaCppModel( + params={ + # OpenAI params + "temperature": 0.7, + "max_tokens": 100, + "top_p": 0.9, + "seed": 42, + # All llama.cpp specific params + "repeat_penalty": 1.1, + "top_k": 40, + "min_p": 0.05, + "typical_p": 0.95, + "tfs_z": 0.97, + "top_a": 0.1, + "mirostat": 2, + "mirostat_lr": 0.1, + "mirostat_ent": 5.0, + "grammar": "root ::= answer", + "json_schema": {"type": "object"}, + "penalty_last_n": 256, + "n_probs": 5, + "min_keep": 1, + "ignore_eos": False, + "logit_bias": {100: 5.0, 200: -5.0}, + "cache_prompt": True, + "slot_id": 1, + "samplers": ["top_k", "tfs_z", "typical_p"], + } + ) + + messages = [{"role": "user", "content": [{"text": "Test"}]}] + request = model._format_request(messages) + + # Check OpenAI params are in root + assert request["temperature"] == 0.7 + assert request["max_tokens"] == 100 + assert request["top_p"] == 0.9 + assert 
request["seed"] == 42 + + # Grammar and json_schema go directly in request for llama.cpp + assert request["grammar"] == "root ::= answer" + assert request["json_schema"] == {"type": "object"} + + # Check all other llama.cpp params are in extra_body + assert "extra_body" in request + extra = request["extra_body"] + assert extra["repeat_penalty"] == 1.1 + assert extra["top_k"] == 40 + assert extra["min_p"] == 0.05 + assert extra["typical_p"] == 0.95 + assert extra["tfs_z"] == 0.97 + assert extra["top_a"] == 0.1 + assert extra["mirostat"] == 2 + assert extra["mirostat_lr"] == 0.1 + assert extra["mirostat_ent"] == 5.0 + assert extra["penalty_last_n"] == 256 + assert extra["n_probs"] == 5 + assert extra["min_keep"] == 1 + assert extra["ignore_eos"] is False + assert extra["logit_bias"] == {100: 5.0, 200: -5.0} + assert extra["cache_prompt"] is True + assert extra["slot_id"] == 1 + assert extra["samplers"] == ["top_k", "tfs_z", "typical_p"] + + +def test_format_request_with_tools() -> None: + """Test request formatting with tool specifications.""" + model = LlamaCppModel() + + messages = [ + {"role": "user", "content": [{"text": "What's the weather?"}]}, + ] + + tool_specs = [ + { + "name": "get_weather", + "description": "Get current weather", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "location": {"type": "string"}, + }, + "required": ["location"], + } + }, + } + ] + + request = model._format_request(messages, tool_specs=tool_specs) + + assert "tools" in request + assert len(request["tools"]) == 1 + assert request["tools"][0]["function"]["name"] == "get_weather" + + +def test_update_config() -> None: + """Test configuration update.""" + model = LlamaCppModel(model_id="initial-model") + + assert model.config["model_id"] == "initial-model" + + model.update_config(model_id="updated-model", params={"temperature": 0.5}) + + assert model.config["model_id"] == "updated-model" + assert model.config["params"]["temperature"] == 0.5 + + +def test_get_config() -> None: + """Test configuration retrieval.""" + config = { + "model_id": "test-model", + "params": {"temperature": 0.9}, + } + model = LlamaCppModel(**config) + + retrieved_config = model.get_config() + + assert retrieved_config["model_id"] == "test-model" + assert retrieved_config["params"]["temperature"] == 0.9 + + +@pytest.mark.asyncio +async def test_stream_basic() -> None: + """Test basic streaming functionality.""" + model = LlamaCppModel() + + # Mock HTTP response with Server-Sent Events format + mock_response_lines = [ + 'data: {"choices": [{"delta": {"content": "Hello"}}]}', + 'data: {"choices": [{"delta": {"content": " world"}, "finish_reason": "stop"}]}', + 'data: {"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}', + "data: [DONE]", + ] + + async def mock_aiter_lines(): + for line in mock_response_lines: + yield line + + mock_response = AsyncMock() + mock_response.aiter_lines = mock_aiter_lines + mock_response.raise_for_status = AsyncMock() + + with patch.object(model.client, "post", return_value=mock_response): + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + chunks = [] + async for chunk in model.stream(messages): + chunks.append(chunk) + + # Verify we got the expected chunks + assert any("messageStart" in chunk for chunk in chunks) + assert any( + "contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == "Hello" for chunk in chunks + ) + assert any( + "contentBlockDelta" in chunk and chunk["contentBlockDelta"]["delta"]["text"] == " world" for 
chunk in chunks + ) + assert any("messageStop" in chunk for chunk in chunks) + + +@pytest.mark.asyncio +async def test_structured_output() -> None: + """Test structured output functionality.""" + + class TestOutput(BaseModel): + """Test output model for structured output testing.""" + + answer: str + confidence: float + + model = LlamaCppModel() + + # Mock successful JSON response using the new structured_output implementation + mock_response_text = '{"answer": "yes", "confidence": 0.95}' + + # Create mock stream that returns JSON + async def mock_stream(*_args, **_kwargs): + # Verify json_schema was set + assert "json_schema" in model.config.get("params", {}) + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] + + events = [] + async for event in model.structured_output(TestOutput, messages): + events.append(event) + + # Check we got the output + output_event = next((e for e in events if "output" in e), None) + assert output_event is not None + assert output_event["output"].answer == "yes" + assert output_event["output"].confidence == 0.95 + + +def test_timeout_configuration() -> None: + """Test timeout configuration.""" + # Test that timeout configuration is accepted without error + model = LlamaCppModel(timeout=30.0) + assert model.client.timeout is not None + + # Test with tuple timeout + model2 = LlamaCppModel(timeout=(10.0, 60.0)) + assert model2.client.timeout is not None + + +def test_max_retries_configuration() -> None: + """Test max retries configuration is handled gracefully.""" + # Since httpx doesn't use max_retries in the same way, + # we just test that the model initializes without error + model = LlamaCppModel() + assert model.config["model_id"] == "default" + + +def test_grammar_constraint_via_params() -> None: + """Test grammar constraint via params.""" + grammar = """ + root ::= answer + answer ::= "yes" | "no" + """ + model = LlamaCppModel(params={"grammar": grammar}) + + assert model.config["params"]["grammar"] == grammar + + # Update grammar via update_config + new_grammar = "root ::= [0-9]+" + model.update_config(params={"grammar": new_grammar}) + + assert model.config["params"]["grammar"] == new_grammar + + +def test_json_schema_via_params() -> None: + """Test JSON schema constraint via params.""" + schema = { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], + } + model = LlamaCppModel(params={"json_schema": schema}) + + assert model.config["params"]["json_schema"] == schema + + +@pytest.mark.asyncio +async def test_stream_with_context_overflow_error() -> None: + """Test stream handling of context overflow errors.""" + model = LlamaCppModel() + + # Create HTTP error response + error_response = httpx.Response( + status_code=400, + json={"error": {"message": "Context window exceeded. 
Max context length is 4096 tokens"}}, + request=httpx.Request("POST", "http://test.com"), + ) + error = httpx.HTTPStatusError("Bad Request", request=error_response.request, response=error_response) + + # Mock the client to raise the error + with patch.object(model.client, "post", side_effect=error): + messages = [{"role": "user", "content": [{"text": "Very long message"}]}] + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Context window exceeded" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_stream_with_server_overload_error() -> None: + """Test stream handling of server overload errors.""" + model = LlamaCppModel() + + # Create HTTP error response for 503 + error_response = httpx.Response( + status_code=503, + text="Server is busy", + request=httpx.Request("POST", "http://test.com"), + ) + error = httpx.HTTPStatusError( + "Service Unavailable", + request=error_response.request, + response=error_response, + ) + + # Mock the client to raise the error + with patch.object(model.client, "post", side_effect=error): + messages = [{"role": "user", "content": [{"text": "Test"}]}] + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "server is busy or overloaded" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_structured_output_with_json_schema() -> None: + """Test structured output using JSON schema.""" + + class TestOutput(BaseModel): + """Test output model for JSON schema testing.""" + + answer: str + confidence: float + + model = LlamaCppModel() + + # Mock successful JSON response + mock_response_text = '{"answer": "yes", "confidence": 0.95}' + + # Create mock stream that returns JSON + async def mock_stream(*_args, **_kwargs): + # Check that json_schema was set correctly + assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": mock_response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": [{"text": "Is the earth round?"}]}] + + events = [] + async for event in model.structured_output(TestOutput, messages): + events.append(event) + + # Check we got the output + output_event = next((e for e in events if "output" in e), None) + assert output_event is not None + assert output_event["output"].answer == "yes" + assert output_event["output"].confidence == 0.95 + + +@pytest.mark.asyncio +async def test_structured_output_invalid_json_error() -> None: + """Test structured output raises error for invalid JSON.""" + + class TestOutput(BaseModel): + """Test output model for invalid JSON testing.""" + + value: int + + model = LlamaCppModel() + + # Mock stream that returns invalid JSON + async def mock_stream(*_args, **_kwargs): + # Check that json_schema was set correctly + assert model.config["params"]["json_schema"] == TestOutput.model_json_schema() + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": "This is not valid JSON"}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + with patch.object(model, "stream", side_effect=mock_stream): + messages = [{"role": "user", "content": 
[{"text": "Give me a number"}]}] + + with pytest.raises(json.JSONDecodeError): + async for _ in model.structured_output(TestOutput, messages): + pass + + +def test_format_audio_content() -> None: + """Test formatting of audio content for llama.cpp multimodal models.""" + model = LlamaCppModel() + + # Create test audio data + audio_bytes = b"fake audio data" + audio_content = {"audio": {"source": {"bytes": audio_bytes}, "format": "wav"}} + + # Format the content + result = model._format_message_content(audio_content) + + # Verify the structure + assert result["type"] == "input_audio" + assert "input_audio" in result + assert "data" in result["input_audio"] + assert "format" in result["input_audio"] + + # Verify the data is base64 encoded + decoded = base64.b64decode(result["input_audio"]["data"]) + assert decoded == audio_bytes + + # Verify format is preserved + assert result["input_audio"]["format"] == "wav" + + +def test_format_audio_content_default_format() -> None: + """Test audio content formatting uses wav as default format.""" + model = LlamaCppModel() + + audio_content = { + "audio": {"source": {"bytes": b"test audio"}} + # No format specified + } + + result = model._format_message_content(audio_content) + + # Should default to wav + assert result["input_audio"]["format"] == "wav" + + +def test_format_messages_with_audio() -> None: + """Test that _format_messages properly handles audio content.""" + model = LlamaCppModel() + + # Create messages with audio content + messages = [ + { + "role": "user", + "content": [ + {"text": "Listen to this audio:"}, + {"audio": {"source": {"bytes": b"audio data"}, "format": "mp3"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 2 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Listen to this audio:" + + # Check audio content + assert result[0]["content"][1]["type"] == "input_audio" + assert "input_audio" in result[0]["content"][1] + assert result[0]["content"][1]["input_audio"]["format"] == "mp3" + + +def test_format_messages_with_system_prompt() -> None: + """Test _format_messages includes system prompt.""" + model = LlamaCppModel() + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant" + + result = model._format_messages(messages, system_prompt) + + # Should have system message first + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[0]["content"] == system_prompt + assert result[1]["role"] == "user" + + +def test_format_messages_with_image() -> None: + """Test that _format_messages properly handles image content.""" + model = LlamaCppModel() + + # Create messages with image content + messages = [ + { + "role": "user", + "content": [ + {"text": "Describe this image:"}, + {"image": {"source": {"bytes": b"image data"}, "format": "png"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 2 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Describe this image:" + + # Check image content uses standard format + assert result[0]["content"][1]["type"] == "image_url" + assert "image_url" in result[0]["content"][1] + 
assert "url" in result[0]["content"][1]["image_url"] + assert result[0]["content"][1]["image_url"]["url"].startswith("data:image/png;base64,") + + +def test_format_messages_with_mixed_content() -> None: + """Test that _format_messages handles mixed audio and image content correctly.""" + model = LlamaCppModel() + + # Create messages with both audio and image content + messages = [ + { + "role": "user", + "content": [ + {"text": "Analyze this media:"}, + {"audio": {"source": {"bytes": b"audio data"}, "format": "wav"}}, + {"image": {"source": {"bytes": b"image data"}, "format": "jpg"}}, + ], + } + ] + + # Format the messages + result = model._format_messages(messages) + + # Check structure + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 3 + + # Check text content + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "Analyze this media:" + + # Check audio content uses llama.cpp specific format + assert result[0]["content"][1]["type"] == "input_audio" + assert "input_audio" in result[0]["content"][1] + assert result[0]["content"][1]["input_audio"]["format"] == "wav" + + # Check image content uses standard OpenAI format + assert result[0]["content"][2]["type"] == "image_url" + assert "image_url" in result[0]["content"][2] + assert result[0]["content"][2]["image_url"]["url"].startswith("data:image/jpeg;base64,") diff --git a/tests_integ/models/test_model_llamacpp.py b/tests_integ/models/test_model_llamacpp.py new file mode 100644 index 000000000..95047e7ab --- /dev/null +++ b/tests_integ/models/test_model_llamacpp.py @@ -0,0 +1,510 @@ +"""Integration tests for llama.cpp model provider. + +These tests require a running llama.cpp server instance. +To run these tests: +1. Start llama.cpp server: llama-server -m model.gguf --host 0.0.0.0 --port 8080 +2. Run: pytest tests_integ/models/test_model_llamacpp.py + +Set LLAMACPP_TEST_URL environment variable to use a different server URL. +""" + +import os + +import pytest +from pydantic import BaseModel + +from strands.models.llamacpp import LlamaCppModel +from strands.types.content import Message + +# Get server URL from environment or use default +LLAMACPP_URL = os.environ.get("LLAMACPP_TEST_URL", "http://localhost:8080/v1") + +# Skip these tests if LLAMACPP_SKIP_TESTS is set +pytestmark = pytest.mark.skipif( + os.environ.get("LLAMACPP_SKIP_TESTS", "true").lower() == "true", + reason="llama.cpp integration tests disabled (set LLAMACPP_SKIP_TESTS=false to enable)", +) + + +class WeatherOutput(BaseModel): + """Test output model for structured responses.""" + + temperature: float + condition: str + location: str + + +@pytest.fixture +async def llamacpp_model() -> LlamaCppModel: + """Fixture to create a llama.cpp model instance.""" + return LlamaCppModel(base_url=LLAMACPP_URL) + + +# Integration tests for LlamaCppModel with a real server + + +@pytest.mark.asyncio +async def test_basic_completion(llamacpp_model: LlamaCppModel) -> None: + """Test basic text completion.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Say 'Hello, World!' and nothing else."}]}, + ] + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + assert "Hello, World!" 
in response_text + + +@pytest.mark.asyncio +async def test_system_prompt(llamacpp_model: LlamaCppModel) -> None: + """Test completion with system prompt.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Who are you?"}]}, + ] + + system_prompt = "You are a helpful AI assistant named Claude." + + response_text = "" + async for event in llamacpp_model.stream(messages, system_prompt=system_prompt): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should reflect the system prompt + assert len(response_text) > 0 + assert "assistant" in response_text.lower() or "claude" in response_text.lower() + + +@pytest.mark.asyncio +async def test_streaming_chunks(llamacpp_model: LlamaCppModel) -> None: + """Test that streaming returns proper chunk sequence.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Count from 1 to 3."}]}, + ] + + chunk_types = [] + async for event in llamacpp_model.stream(messages): + chunk_types.append(next(iter(event.keys()))) + + # Verify proper chunk sequence + assert chunk_types[0] == "messageStart" + assert chunk_types[1] == "contentBlockStart" + assert "contentBlockDelta" in chunk_types + assert chunk_types[-3] == "contentBlockStop" + assert chunk_types[-2] == "messageStop" + assert chunk_types[-1] == "metadata" + + +@pytest.mark.asyncio +async def test_temperature_parameter(llamacpp_model: LlamaCppModel) -> None: + """Test temperature parameter affects randomness.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Generate a random word."}]}, + ] + + # Low temperature should give more consistent results + llamacpp_model.update_config(params={"temperature": 0.1, "seed": 42}) + + response1 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response1 += delta["text"] + + # Same seed and low temperature should give similar result + response2 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response2 += delta["text"] + + # With low temperature and same seed, responses should be very similar + assert len(response1) > 0 + assert len(response2) > 0 + + +@pytest.mark.asyncio +async def test_max_tokens_limit(llamacpp_model: LlamaCppModel) -> None: + """Test max_tokens parameter limits response length.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Tell me a very long story about dragons."}]}, + ] + + # Set very low token limit + llamacpp_model.update_config(params={"max_tokens": 10}) + + token_count = 0 + async for event in llamacpp_model.stream(messages): + if "metadata" in event: + usage = event["metadata"]["usage"] + token_count = usage["outputTokens"] + if "messageStop" in event: + stop_reason = event["messageStop"]["stopReason"] + + # Should stop due to max_tokens + assert token_count <= 15 # Allow small overage due to tokenization + assert stop_reason == "max_tokens" + + +@pytest.mark.asyncio +async def test_structured_output(llamacpp_model: LlamaCppModel) -> None: + """Test structured output generation.""" + messages: list[Message] = [ + { + "role": "user", + "content": [ + { + "text": "What's the weather like in Paris? " + "Respond with temperature in Celsius, condition, and location." 
+ } + ], + }, + ] + + # Enable JSON response format for structured output + llamacpp_model.update_config(params={"response_format": {"type": "json_object"}}) + + result = None + async for event in llamacpp_model.structured_output(WeatherOutput, messages): + if "output" in event: + result = event["output"] + + assert result is not None + assert isinstance(result, WeatherOutput) + assert isinstance(result.temperature, float) + assert isinstance(result.condition, str) + assert result.location.lower() == "paris" + + +@pytest.mark.asyncio +async def test_llamacpp_specific_params(llamacpp_model: LlamaCppModel) -> None: + """Test llama.cpp specific parameters.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Say 'test' five times."}]}, + ] + + # Use llama.cpp specific parameters + llamacpp_model.update_config( + params={ + "repeat_penalty": 1.5, # Penalize repetition + "top_k": 10, # Limit vocabulary + "min_p": 0.1, # Min-p sampling + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should contain "test" but with repetition penalty it might vary + assert "test" in response_text.lower() + + +@pytest.mark.asyncio +async def test_advanced_sampling_params(llamacpp_model: LlamaCppModel) -> None: + """Test advanced sampling parameters.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Generate a random sentence about space."}]}, + ] + + # Test advanced sampling parameters + llamacpp_model.update_config( + params={ + "temperature": 0.8, + "tfs_z": 0.95, # Tail-free sampling + "top_a": 0.1, # Top-a sampling + "typical_p": 0.9, # Typical-p sampling + "penalty_last_n": 64, # Penalty context window + "min_keep": 1, # Minimum tokens to keep + "samplers": ["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"], + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate something about space + assert len(response_text) > 0 + assert any(word in response_text.lower() for word in ["space", "star", "planet", "galaxy", "universe"]) + + +@pytest.mark.asyncio +async def test_mirostat_sampling(llamacpp_model: LlamaCppModel) -> None: + """Test Mirostat sampling modes.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Write a short poem."}]}, + ] + + # Test Mirostat v2 + llamacpp_model.update_config( + params={ + "mirostat": 2, + "mirostat_lr": 0.1, + "mirostat_ent": 5.0, + "seed": 42, # For reproducibility + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate a poem + assert len(response_text) > 20 + assert "\n" in response_text # Poems typically have line breaks + + +@pytest.mark.asyncio +async def test_grammar_constraint(llamacpp_model: LlamaCppModel) -> None: + """Test grammar constraint feature (llama.cpp specific).""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Is the sky blue? 
Answer yes or no."}]}, + ] + + # Set grammar constraint via params + grammar = """ + root ::= answer + answer ::= "yes" | "no" + """ + llamacpp_model.update_config(params={"grammar": grammar}) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Response should be exactly "yes" or "no" + assert response_text.strip().lower() in ["yes", "no"] + + +@pytest.mark.asyncio +async def test_json_schema_constraint(llamacpp_model: LlamaCppModel) -> None: + """Test JSON schema constraint feature.""" + messages: list[Message] = [ + { + "role": "user", + "content": [{"text": "Describe the weather in JSON format with temperature and description."}], + }, + ] + + # Set JSON schema constraint via params + schema = { + "type": "object", + "properties": {"temperature": {"type": "number"}, "description": {"type": "string"}}, + "required": ["temperature", "description"], + } + llamacpp_model.update_config(params={"json_schema": schema}) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should be valid JSON matching the schema + import json + + data = json.loads(response_text.strip()) + assert "temperature" in data + assert "description" in data + assert isinstance(data["temperature"], (int, float)) + assert isinstance(data["description"], str) + + +@pytest.mark.asyncio +async def test_logit_bias(llamacpp_model: LlamaCppModel) -> None: + """Test logit bias feature.""" + messages: list[Message] = [ + {"role": "user", "content": [{"text": "Choose between 'cat' and 'dog'."}]}, + ] + + # This is a simplified test - in reality you'd need to know the actual token IDs + # for "cat" and "dog" in the model's vocabulary + llamacpp_model.update_config( + params={ + "logit_bias": { + # These are placeholder token IDs - real implementation would need actual token IDs + 1234: 10.0, # Strong positive bias (hypothetical "cat" token) + 5678: -10.0, # Strong negative bias (hypothetical "dog" token) + }, + "seed": 42, # For reproducibility + } + ) + + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # Should generate text (exact behavior depends on actual token IDs) + assert len(response_text) > 0 + + +@pytest.mark.asyncio +async def test_cache_prompt(llamacpp_model: LlamaCppModel) -> None: + """Test prompt caching feature.""" + messages: list[Message] = [ + {"role": "system", "content": [{"text": "You are a helpful assistant. Always be concise."}]}, + {"role": "user", "content": [{"text": "What is 2+2?"}]}, + ] + + # Enable prompt caching + llamacpp_model.update_config( + params={ + "cache_prompt": True, + "slot_id": 0, # Use specific slot for caching + } + ) + + # First request + response1 = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response1 += delta["text"] + + # Second request with same system prompt should use cache + messages2 = [ + {"role": "system", "content": [{"text": "You are a helpful assistant. 
Always be concise."}]}, + {"role": "user", "content": [{"text": "What is 3+3?"}]}, + ] + + response2 = "" + async for event in llamacpp_model.stream(messages2): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response2 += delta["text"] + + # Both should give valid responses + assert "4" in response1 + assert "6" in response2 + + +@pytest.mark.asyncio +async def test_concurrent_requests(llamacpp_model: LlamaCppModel) -> None: + """Test handling multiple concurrent requests.""" + import asyncio + + async def make_request(prompt: str) -> str: + messages: list[Message] = [ + {"role": "user", "content": [{"text": prompt}]}, + ] + + response = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response += delta["text"] + return response + + # Make concurrent requests + prompts = [ + "Say 'one'", + "Say 'two'", + "Say 'three'", + ] + + responses = await asyncio.gather(*[make_request(p) for p in prompts]) + + # Each response should contain the expected number + assert "one" in responses[0].lower() + assert "two" in responses[1].lower() + assert "three" in responses[2].lower() + + +@pytest.mark.asyncio +async def test_enhanced_structured_output(llamacpp_model: LlamaCppModel) -> None: + """Test enhanced structured output with native JSON schema support.""" + + class BookInfo(BaseModel): + title: str + author: str + year: int + genres: list[str] + + messages: list[Message] = [ + { + "role": "user", + "content": [ + { + "text": "Create information about a fictional science fiction book. " + "Include title, author, publication year, and 2-3 genres." + } + ], + }, + ] + + result = None + events = [] + async for event in llamacpp_model.structured_output(BookInfo, messages): + events.append(event) + if "output" in event: + result = event["output"] + + # Verify we got structured output + assert result is not None + assert isinstance(result, BookInfo) + assert isinstance(result.title, str) and len(result.title) > 0 + assert isinstance(result.author, str) and len(result.author) > 0 + assert isinstance(result.year, int) and 1900 <= result.year <= 2100 + assert isinstance(result.genres, list) and len(result.genres) >= 2 + assert all(isinstance(genre, str) for genre in result.genres) + + # Should have streamed events before the output + assert len(events) > 1 + + +@pytest.mark.asyncio +async def test_context_overflow_handling(llamacpp_model: LlamaCppModel) -> None: + """Test proper handling of context window overflow.""" + # Create a very long message that might exceed context + long_text = "This is a test sentence. 
" * 1000 + messages: list[Message] = [ + {"role": "user", "content": [{"text": f"Summarize this text: {long_text}"}]}, + ] + + try: + response_text = "" + async for event in llamacpp_model.stream(messages): + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + response_text += delta["text"] + + # If it succeeds, we got a response + assert len(response_text) > 0 + except Exception as e: + # If it fails, it should be our custom error + from strands.types.exceptions import ContextWindowOverflowException + + if isinstance(e, ContextWindowOverflowException): + assert "context" in str(e).lower() + else: + # Some other error - re-raise to see what it was + raise From 4fbe46a8b99c17aa28330a347fa1b6f5a0247c1e Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:18:01 -0400 Subject: [PATCH 089/104] fix(llama.cpp) - add ToolChoice and validation of model config values (#838) --- src/strands/models/llamacpp.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 94a225a06..25d42a6c8 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -33,7 +33,8 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolChoice, ToolSpec +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -149,12 +150,15 @@ def __init__( (connect, read) timeouts. **model_config: Configuration options for the llama.cpp model. """ + validate_config_keys(model_config, self.LlamaCppConfig) + # Set default model_id if not provided if "model_id" not in model_config: model_config["model_id"] = "default" self.base_url = base_url.rstrip("/") self.config = dict(model_config) + logger.debug("config=<%s> | initializing", self.config) # Configure HTTP client if isinstance(timeout, tuple): @@ -173,12 +177,6 @@ def __init__( timeout=timeout_obj, ) - logger.debug( - "base_url=<%s>, model_id=<%s> | initializing llama.cpp provider", - base_url, - model_config.get("model_id"), - ) - @override def update_config(self, **model_config: Unpack[LlamaCppConfig]) -> None: # type: ignore[override] """Update the llama.cpp model configuration with provided arguments. @@ -186,6 +184,7 @@ def update_config(self, **model_config: Unpack[LlamaCppConfig]) -> None: # type Args: **model_config: Configuration overrides. """ + validate_config_keys(model_config, self.LlamaCppConfig) self.config.update(model_config) @override @@ -514,6 +513,7 @@ async def stream( messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the llama.cpp model. @@ -522,6 +522,8 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. 
**Note: This parameter is accepted for + interface consistency but is currently ignored for this model provider.** **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -531,19 +533,21 @@ async def stream( ContextWindowOverflowException: When the context window is exceeded. ModelThrottledException: When the llama.cpp server is overloaded. """ + warn_on_tool_choice_not_supported(tool_choice) + # Track request start time for latency calculation start_time = time.perf_counter() try: - logger.debug("formatting request for llama.cpp server") + logger.debug("formatting request") request = self._format_request(messages, tool_specs, system_prompt) logger.debug("request=<%s>", request) - logger.debug("sending request to llama.cpp server") + logger.debug("invoking model") response = await self.client.post("/v1/chat/completions", json=request) response.raise_for_status() - logger.debug("processing streaming response") + logger.debug("got response from model") yield self._format_chunk({"chunk_type": "message_start"}) yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) @@ -648,12 +652,10 @@ async def stream( yield self._format_chunk({"chunk_type": "content_stop"}) # Send stop reason - logger.debug("finish_reason=%s, tool_calls=%s", finish_reason, bool(tool_calls)) if finish_reason == "tool_calls" or tool_calls: stop_reason = "tool_calls" # Changed from "tool_use" to match format_chunk expectations else: stop_reason = finish_reason or "end_turn" - logger.debug("stop_reason=%s", stop_reason) yield self._format_chunk({"chunk_type": "message_stop", "data": stop_reason}) # Send usage metadata if available @@ -676,7 +678,7 @@ async def stream( } ) - logger.debug("finished streaming response") + logger.debug("finished streaming response from model") except httpx.HTTPStatusError as e: if e.response.status_code == 400: From bf4e3e4128891df79753d064f26610769875e93b Mon Sep 17 00:00:00 2001 From: Vamil Gandhi Date: Thu, 11 Sep 2025 11:06:06 -0400 Subject: [PATCH 090/104] feat(telemetry): add cache usage metrics to OpenTelemetry spans (#825) Adds cacheReadInputTokens and cacheWriteInputTokens to span attributes in both end_model_invoke_span and end_agent_span methods to enable monitoring of cache token usage for cost calculation. 
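
For illustration only, a standalone sketch (not the tracer implementation) of how an accumulated usage payload maps onto the new span attributes; the attribute names and the `get(..., 0)` fallbacks are taken from the diff below, the usage values are made up, and the real spans also keep the older prompt/completion token aliases:

```python
from typing import Any, Mapping


def usage_to_span_attributes(usage: Mapping[str, int]) -> dict[str, Any]:
    """Map an accumulated usage payload onto the gen_ai.* span attributes.

    Cache counters are optional: providers that do not report them fall back
    to 0, mirroring the usage.get(..., 0) calls added in this change.
    """
    return {
        "gen_ai.usage.input_tokens": usage["inputTokens"],
        "gen_ai.usage.output_tokens": usage["outputTokens"],
        "gen_ai.usage.total_tokens": usage["totalTokens"],
        "gen_ai.usage.cache_read_input_tokens": usage.get("cacheReadInputTokens", 0),
        "gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0),
    }


# Example: an invocation that read 5 cached input tokens and wrote 3.
print(usage_to_span_attributes({
    "inputTokens": 10,
    "outputTokens": 20,
    "totalTokens": 30,
    "cacheReadInputTokens": 5,
    "cacheWriteInputTokens": 3,
}))
```
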
Closes #776 Co-authored-by: Vamil Gandhi --- src/strands/telemetry/tracer.py | 4 ++ tests/strands/telemetry/test_tracer.py | 62 ++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 6b429393d..9e170571a 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -263,6 +263,8 @@ def end_model_invoke_span( "gen_ai.usage.completion_tokens": usage["outputTokens"], "gen_ai.usage.output_tokens": usage["outputTokens"], "gen_ai.usage.total_tokens": usage["totalTokens"], + "gen_ai.usage.cache_read_input_tokens": usage.get("cacheReadInputTokens", 0), + "gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0), } self._add_event( @@ -491,6 +493,8 @@ def end_agent_span( "gen_ai.usage.input_tokens": accumulated_usage["inputTokens"], "gen_ai.usage.output_tokens": accumulated_usage["outputTokens"], "gen_ai.usage.total_tokens": accumulated_usage["totalTokens"], + "gen_ai.usage.cache_read_input_tokens": accumulated_usage.get("cacheReadInputTokens", 0), + "gen_ai.usage.cache_write_input_tokens": accumulated_usage.get("cacheWriteInputTokens", 0), } ) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 586911bef..568fff130 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -177,6 +177,8 @@ def test_end_model_invoke_span(mock_span): mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, @@ -404,6 +406,8 @@ def test_end_agent_span(mock_span): mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) mock_span.add_event.assert_any_call( "gen_ai.choice", attributes={"message": "Agent response", "finish_reason": "end_turn"}, @@ -412,6 +416,64 @@ def test_end_agent_span(mock_span): mock_span.end.assert_called_once() +def test_end_model_invoke_span_with_cache_metrics(mock_span): + """Test ending a model invoke span with cache metrics.""" + tracer = Tracer() + message = {"role": "assistant", "content": [{"text": "Response"}]} + usage = Usage( + inputTokens=10, + outputTokens=20, + totalTokens=30, + cacheReadInputTokens=5, + cacheWriteInputTokens=3, + ) + stop_reason: StopReason = "end_turn" + + tracer.end_model_invoke_span(mock_span, message, usage, stop_reason) + + mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) + mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) + mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) + mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) + mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) + 
mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 5) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 3) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() + + +def test_end_agent_span_with_cache_metrics(mock_span): + """Test ending an agent span with cache metrics.""" + tracer = Tracer() + + # Mock AgentResult with metrics including cache tokens + mock_metrics = mock.MagicMock() + mock_metrics.accumulated_usage = { + "inputTokens": 50, + "outputTokens": 100, + "totalTokens": 150, + "cacheReadInputTokens": 25, + "cacheWriteInputTokens": 10, + } + + mock_response = mock.MagicMock() + mock_response.metrics = mock_metrics + mock_response.stop_reason = "end_turn" + mock_response.__str__ = mock.MagicMock(return_value="Agent response") + + tracer.end_agent_span(mock_span, mock_response) + + mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) + mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) + mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) + mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) + mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 25) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 10) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() + + def test_get_tracer_singleton(): """Test that get_tracer returns a singleton instance.""" # Reset the singleton first From 7f77a593e4aefec470573e1bafd2935f63f383b5 Mon Sep 17 00:00:00 2001 From: Himanshu <101276134+waitasecant@users.noreply.github.com> Date: Fri, 12 Sep 2025 00:15:42 +0530 Subject: [PATCH 091/104] docs: improve docstring formatting for `invoke_async` function in `Agent` class. [for better VS Code hover] (#846) --- src/strands/agent/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 05e15a5b1..bb602d66b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -425,7 +425,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR **kwargs: Additional parameters to pass through the event loop. Returns: - Result object containing: + Result: object containing: - stop_reason: Why the event loop stopped (e.g., "end_turn", "max_tokens") - message: The final message from the model From 7d1bdbf0e89fd46caeabefd07d19a5c078633c56 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:57:34 -0400 Subject: [PATCH 092/104] ci: bump actions/setup-python from 5 to 6 (#796) Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5 to 6. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/setup-python dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration-test.yml | 2 +- .github/workflows/pypi-publish-on-release.yml | 2 +- .github/workflows/test-lint.yml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index d410bb712..0befb4810 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -57,7 +57,7 @@ jobs: ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo persist-credentials: false # Don't persist credentials for subsequent actions - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.10' - name: Install dependencies diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index c2420d747..ff19e46b1 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -27,7 +27,7 @@ jobs: persist-credentials: false - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.10' diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index c0ed4faca..1d1eb8973 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -56,7 +56,7 @@ jobs: ref: ${{ inputs.ref }} # Explicitly define which commit to check out persist-credentials: false # Don't persist credentials for subsequent actions - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -79,7 +79,7 @@ jobs: persist-credentials: false - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.10' cache: 'pip' From eace0ecfaba239fc679e003040436b51c1b04b02 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:57:49 -0400 Subject: [PATCH 093/104] ci: bump actions/github-script from 7 to 8 (#801) Bumps [actions/github-script](https://github.com/actions/github-script) from 7 to 8. - [Release notes](https://github.com/actions/github-script/releases) - [Commits](https://github.com/actions/github-script/compare/v7...v8) --- updated-dependencies: - dependency-name: actions/github-script dependency-version: '8' dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 0befb4810..dc2f20c7a 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -12,7 +12,7 @@ jobs: approval-env: ${{ steps.collab-check.outputs.result }} steps: - name: Collaborator Check - uses: actions/github-script@v7 + uses: actions/github-script@v8 id: collab-check with: result-encoding: string From fe7a700e4d88e8ac5f8e2b3af74c8ff674d6ab47 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:11:53 -0400 Subject: [PATCH 094/104] ci: bump aws-actions/configure-aws-credentials from 4 to 5 (#795) Bumps [aws-actions/configure-aws-credentials](https://github.com/aws-actions/configure-aws-credentials) from 4 to 5. - [Release notes](https://github.com/aws-actions/configure-aws-credentials/releases) - [Changelog](https://github.com/aws-actions/configure-aws-credentials/blob/main/CHANGELOG.md) - [Commits](https://github.com/aws-actions/configure-aws-credentials/compare/v4...v5) --- updated-dependencies: - dependency-name: aws-actions/configure-aws-credentials dependency-version: '5' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index dc2f20c7a..7496e45ef 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -46,7 +46,7 @@ jobs: contents: read steps: - name: Configure Credentials - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} aws-region: us-east-1 From f12fee856dd6d6749c771cc65c809a5d52f851ae Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 12 Sep 2025 12:10:46 -0400 Subject: [PATCH 095/104] fix: Add type to tool_input (#854) --- src/strands/tools/decorator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 8b218dfa1..4923a44ee 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -447,7 +447,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw """ # This is a tool use call - process accordingly tool_use_id = tool_use.get("toolUseId", "unknown") - tool_input = tool_use.get("input", {}) + tool_input: dict[str, Any] = tool_use.get("input", {}) try: # Validate input against the Pydantic model From cbdab3255602344c782a89499159018e8fb57dcc Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 12 Sep 2025 18:17:17 +0200 Subject: [PATCH 096/104] feat(swarm): Make entry point configurable (#851) --- src/strands/multiagent/swarm.py | 28 +++++++++- tests/strands/multiagent/test_swarm.py | 76 ++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 1 deletion(-) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 1c2302c28..620fa5e24 100644 --- a/src/strands/multiagent/swarm.py +++ 
b/src/strands/multiagent/swarm.py @@ -196,6 +196,7 @@ def __init__( self, nodes: list[Agent], *, + entry_point: Agent | None = None, max_handoffs: int = 20, max_iterations: int = 20, execution_timeout: float = 900.0, @@ -207,6 +208,7 @@ def __init__( Args: nodes: List of nodes (e.g. Agent) to include in the swarm + entry_point: Agent to start with. If None, uses the first agent (default: None) max_handoffs: Maximum handoffs to agents and users (default: 20) max_iterations: Maximum node executions within the swarm (default: 20) execution_timeout: Total execution timeout in seconds (default: 900.0) @@ -218,6 +220,7 @@ def __init__( """ super().__init__() + self.entry_point = entry_point self.max_handoffs = max_handoffs self.max_iterations = max_iterations self.execution_timeout = execution_timeout @@ -276,7 +279,11 @@ async def invoke_async( logger.debug("starting swarm execution") # Initialize swarm state with configuration - initial_node = next(iter(self.nodes.values())) # First SwarmNode + if self.entry_point: + initial_node = self.nodes[str(self.entry_point.name)] + else: + initial_node = next(iter(self.nodes.values())) # First SwarmNode + self.state = SwarmState( current_node=initial_node, task=task, @@ -326,9 +333,28 @@ def _setup_swarm(self, nodes: list[Agent]) -> None: self.nodes[node_id] = SwarmNode(node_id=node_id, executor=node) + # Validate entry point if specified + if self.entry_point is not None: + entry_point_node_id = str(self.entry_point.name) + if ( + entry_point_node_id not in self.nodes + or self.nodes[entry_point_node_id].executor is not self.entry_point + ): + available_agents = [ + f"{node_id} ({type(node.executor).__name__})" for node_id, node in self.nodes.items() + ] + raise ValueError(f"Entry point agent not found in swarm nodes. 
Available agents: {available_agents}") + swarm_nodes = list(self.nodes.values()) logger.debug("nodes=<%s> | initialized swarm with nodes", [node.node_id for node in swarm_nodes]) + if self.entry_point: + entry_point_name = getattr(self.entry_point, "name", "unnamed_agent") + logger.debug("entry_point=<%s> | configured entry point", entry_point_name) + else: + first_node = next(iter(self.nodes.keys())) + logger.debug("entry_point=<%s> | using first node as entry point", first_node) + def _validate_swarm(self, nodes: list[Agent]) -> None: """Validate swarm structure and nodes.""" # Check for duplicate object instances diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index be463c7fd..7d3e69695 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -451,6 +451,82 @@ def test_swarm_auto_completion_without_handoff(): no_handoff_agent.invoke_async.assert_called() +def test_swarm_configurable_entry_point(): + """Test swarm with configurable entry point.""" + # Create multiple agents + agent1 = create_mock_agent("agent1", "Agent 1 response") + agent2 = create_mock_agent("agent2", "Agent 2 response") + agent3 = create_mock_agent("agent3", "Agent 3 response") + + # Create swarm with agent2 as entry point + swarm = Swarm([agent1, agent2, agent3], entry_point=agent2) + + # Verify entry point is set correctly + assert swarm.entry_point is agent2 + + # Execute swarm + result = swarm("Test task") + + # Verify agent2 was the first to execute + assert result.status == Status.COMPLETED + assert len(result.node_history) == 1 + assert result.node_history[0].node_id == "agent2" + + +def test_swarm_invalid_entry_point(): + """Test swarm with invalid entry point raises error.""" + agent1 = create_mock_agent("agent1", "Agent 1 response") + agent2 = create_mock_agent("agent2", "Agent 2 response") + agent3 = create_mock_agent("agent3", "Agent 3 response") # Not in swarm + + # Try to create swarm with agent not in the swarm + with pytest.raises(ValueError, match="Entry point agent not found in swarm nodes"): + Swarm([agent1, agent2], entry_point=agent3) + + +def test_swarm_default_entry_point(): + """Test swarm uses first agent as default entry point.""" + agent1 = create_mock_agent("agent1", "Agent 1 response") + agent2 = create_mock_agent("agent2", "Agent 2 response") + + # Create swarm without specifying entry point + swarm = Swarm([agent1, agent2]) + + # Verify no explicit entry point is set + assert swarm.entry_point is None + + # Execute swarm + result = swarm("Test task") + + # Verify first agent was used as entry point + assert result.status == Status.COMPLETED + assert len(result.node_history) == 1 + assert result.node_history[0].node_id == "agent1" + + +def test_swarm_duplicate_agent_names(): + """Test swarm rejects agents with duplicate names.""" + agent1 = create_mock_agent("duplicate_name", "Agent 1 response") + agent2 = create_mock_agent("duplicate_name", "Agent 2 response") + + # Try to create swarm with duplicate names + with pytest.raises(ValueError, match="Node ID 'duplicate_name' is not unique"): + Swarm([agent1, agent2]) + + +def test_swarm_entry_point_same_name_different_object(): + """Test entry point validation with same name but different object.""" + agent1 = create_mock_agent("agent1", "Agent 1 response") + agent2 = create_mock_agent("agent2", "Agent 2 response") + + # Create a different agent with same name as agent1 + different_agent_same_name = create_mock_agent("agent1", "Different agent response") + 
+ # Try to use the different agent as entry point + with pytest.raises(ValueError, match="Entry point agent not found in swarm nodes"): + Swarm([agent1, agent2], entry_point=different_agent_same_name) + + def test_swarm_validate_unsupported_features(): """Test Swarm validation for session persistence and callbacks.""" # Test with normal agent (should work) From 5790a9c0ba8399dbd33f5f584cfd7736aa88cd0e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 12 Sep 2025 12:29:43 -0400 Subject: [PATCH 097/104] ci: update ruff requirement from <0.13.0,>=0.12.0 to >=0.12.0,<0.14.0 (#840) Updates the requirements on [ruff](https://github.com/astral-sh/ruff) to permit the latest version. - [Release notes](https://github.com/astral-sh/ruff/releases) - [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md) - [Commits](https://github.com/astral-sh/ruff/compare/0.12.0...0.13.0) --- updated-dependencies: - dependency-name: ruff dependency-version: 0.13.0 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a0be0ddc6..ac6c3f97a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ dev = [ "pytest-cov>=6.0.0,<7.0.0", "pytest-asyncio>=1.0.0,<1.2.0", "pytest-xdist>=3.0.0,<4.0.0", - "ruff>=0.12.0,<0.13.0", + "ruff>=0.12.0,<0.14.0", ] docs = [ "sphinx>=5.0.0,<6.0.0", From 6a1b2d44d830bcd6bdbeec6ab0342525d63caf4e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 12 Sep 2025 12:32:35 -0400 Subject: [PATCH 098/104] ci: update openai requirement (#827) Updates the requirements on [openai](https://github.com/openai/openai-python) to permit the latest version. - [Release notes](https://github.com/openai/openai-python/releases) - [Changelog](https://github.com/openai/openai-python/blob/main/CHANGELOG.md) - [Commits](https://github.com/openai/openai-python/compare/v1.68.0...v1.107.0) --- updated-dependencies: - dependency-name: openai dependency-version: 1.107.0 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ac6c3f97a..151a80530 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ docs = [ ] litellm = [ "litellm>=1.75.9,<2.0.0", - "openai>=1.68.0,<1.102.0", + "openai>=1.68.0,<1.108.0", ] llamaapi = [ "llama-api-client>=0.1.0,<1.0.0", From 066a427cbb074b5b65c7a14f1bac02796c63315e Mon Sep 17 00:00:00 2001 From: Jonathan Segev Date: Fri, 12 Sep 2025 12:55:07 -0400 Subject: [PATCH 099/104] feat: add automated issue auto-close workflows with dry-run testing (#832) * feat: add GitHub workflow for auto-closing stale issues with dry-run support - Daily workflow checks issues with configurable labels after X days - Removes label if unauthorized users comment, closes if only authorized users - Supports team-based or write-access authorization modes - Includes comprehensive input validation and error handling - Adds manual trigger with dry-run mode for safe testing * fix: Replace deprecated GitHub Search API with Issues API - Replace github.rest.search.issuesAndPullRequests with github.rest.issues.listForRepo - Add pagination support to handle repositories with many labeled issues * feat: remove label immediately on unauthorized comments - Check for unauthorized comments before time validation - Remove the label instantly when non-authorized users respond * feat: add optional replacement label when removing auto-close label - Add REPLACEMENT_LABEL environment variable for optional label substitution - Apply replacement label when unauthorized users comment and auto-close label is removed * feat: Consolidate auto-close workflows into a single matrix-based action - Merge auto-close-3-days.yml and auto-close-7-days.yml into auto-close.yml - Use a matrix strategy to handle both 3-day and 7-day label processing --- .github/workflows/auto-close.yml | 237 +++++++++++++++++++++++++++++++ 1 file changed, 237 insertions(+) create mode 100644 .github/workflows/auto-close.yml diff --git a/.github/workflows/auto-close.yml b/.github/workflows/auto-close.yml new file mode 100644 index 000000000..5c402f619 --- /dev/null +++ b/.github/workflows/auto-close.yml @@ -0,0 +1,237 @@ +name: Auto Close Issues + +on: + schedule: + - cron: '0 14 * * 1-5' # 9 AM EST (2 PM UTC) Monday through Friday + workflow_dispatch: + inputs: + dry_run: + description: 'Run in dry-run mode (no actions taken, only logging)' + required: false + default: 'false' + type: boolean + +jobs: + auto-close: + runs-on: ubuntu-latest + strategy: + matrix: + include: + - label: 'autoclose in 3 days' + days: 3 + issue_types: 'issues' #issues/pulls/both + replacement_label: '' + closure_message: 'This issue has been automatically closed as it was marked for auto-closure by the team and no additional responses was received within 3 days.' + dry_run: 'false' + - label: 'autoclose in 7 days' + days: 7 + issue_types: 'issues' # issues/pulls/both + replacement_label: '' + closure_message: 'This issue has been automatically closed as it was marked for auto-closure by the team and no additional responses was received within 7 days.' 
+ dry_run: 'false' + steps: + - name: Validate and process ${{ matrix.label }} + uses: actions/github-script@v8 + env: + LABEL_NAME: ${{ matrix.label }} + DAYS_TO_WAIT: ${{ matrix.days }} + AUTHORIZED_USERS: '' + AUTH_MODE: 'write-access' + ISSUE_TYPES: ${{ matrix.issue_types }} + DRY_RUN: ${{ matrix.dry_run }} + REPLACEMENT_LABEL: ${{ matrix.replacement_label }} + CLOSE_MESSAGE: ${{matrix.closure_message}} + with: + script: | + const REQUIRED_PERMISSIONS = ['write', 'admin']; + const CLOSE_MESSAGE = process.env.CLOSE_MESSAGE; + const isDryRun = '${{ inputs.dry_run }}' === 'true' || process.env.DRY_RUN === 'true'; + + const config = { + labelName: process.env.LABEL_NAME, + daysToWait: parseInt(process.env.DAYS_TO_WAIT), + authMode: process.env.AUTH_MODE, + authorizedUsers: process.env.AUTHORIZED_USERS?.split(',').map(u => u.trim()).filter(u => u) || [], + issueTypes: process.env.ISSUE_TYPES, + replacementLabel: process.env.REPLACEMENT_LABEL?.trim() || null + }; + + console.log(`šŸ·ļø Processing label: "${config.labelName}" (${config.daysToWait} days)`); + if (isDryRun) console.log('🧪 DRY-RUN MODE: No actions will be taken'); + + const cutoffDate = new Date(); + cutoffDate.setDate(cutoffDate.getDate() - config.daysToWait); + + async function isAuthorizedUser(username) { + try { + if (config.authMode === 'users') { + return config.authorizedUsers.includes(username); + } else if (config.authMode === 'write-access') { + const { data } = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: username + }); + return REQUIRED_PERMISSIONS.includes(data.permission); + } + } catch (error) { + console.log(`āš ļø Failed to check authorization for ${username}: ${error.message}`); + return false; + } + return false; + } + + let allIssues = []; + let page = 1; + + while (true) { + const { data: issues } = await github.rest.issues.listForRepo({ + owner: context.repo.owner, + repo: context.repo.repo, + state: 'open', + labels: config.labelName, + sort: 'updated', + direction: 'desc', + per_page: 100, + page: page + }); + + if (issues.length === 0) break; + allIssues = allIssues.concat(issues); + if (issues.length < 100) break; + page++; + } + + const targetIssues = allIssues.filter(issue => { + if (config.issueTypes === 'issues' && issue.pull_request) return false; + if (config.issueTypes === 'pulls' && !issue.pull_request) return false; + return true; + }); + + console.log(`šŸ” Found ${targetIssues.length} items with label "${config.labelName}"`); + + if (targetIssues.length === 0) { + console.log('āœ… No items to process'); + return; + } + + let closedCount = 0; + let labelRemovedCount = 0; + let skippedCount = 0; + + for (const issue of targetIssues) { + console.log(`\nšŸ“‹ Processing #${issue.number}: ${issue.title}`); + + try { + const { data: events } = await github.rest.issues.listEvents({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issue.number + }); + + const labelEvents = events + .filter(e => e.event === 'labeled' && e.label?.name === config.labelName) + .sort((a, b) => new Date(b.created_at) - new Date(a.created_at)); + + if (labelEvents.length === 0) { + console.log(`āš ļø No label events found for #${issue.number}`); + skippedCount++; + continue; + } + + const lastLabelAdded = new Date(labelEvents[0].created_at); + const labelAdder = labelEvents[0].actor.login; + + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + 
issue_number: issue.number, + since: lastLabelAdded.toISOString() + }); + + let hasUnauthorizedComment = false; + + for (const comment of comments) { + if (comment.user.login === labelAdder) continue; + + const isAuthorized = await isAuthorizedUser(comment.user.login); + if (!isAuthorized) { + console.log(`āŒ New comment from ${comment.user.login}`); + hasUnauthorizedComment = true; + break; + } + } + + if (hasUnauthorizedComment) { + if (isDryRun) { + console.log(`🧪 DRY-RUN: Would remove ${config.labelName} label from #${issue.number}`); + if (config.replacementLabel) { + console.log(`🧪 DRY-RUN: Would add ${config.replacementLabel} label to #${issue.number}`); + } + } else { + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issue.number, + name: config.labelName + }); + console.log(`šŸ·ļø Removed ${config.labelName} label from #${issue.number}`); + + if (config.replacementLabel) { + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issue.number, + labels: [config.replacementLabel] + }); + console.log(`šŸ·ļø Added ${config.replacementLabel} label to #${issue.number}`); + } + } + labelRemovedCount++; + continue; + } + + if (lastLabelAdded > cutoffDate) { + const daysRemaining = Math.ceil((lastLabelAdded - cutoffDate) / (1000 * 60 * 60 * 24)); + console.log(`ā³ Label added too recently (${daysRemaining} days remaining)`); + skippedCount++; + continue; + } + + if (isDryRun) { + console.log(`🧪 DRY-RUN: Would close #${issue.number} with comment`); + } else { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issue.number, + body: CLOSE_MESSAGE + }); + + await github.rest.issues.update({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issue.number, + state: 'closed' + }); + + console.log(`šŸ”’ Closed #${issue.number}`); + } + closedCount++; + } catch (error) { + console.log(`āŒ Error processing #${issue.number}: ${error.message}`); + skippedCount++; + } + } + + console.log(`\nšŸ“Š Summary for "${config.labelName}":`); + if (isDryRun) { + console.log(` 🧪 DRY-RUN MODE - No actual changes made:`); + console.log(` • Issues that would be closed: ${closedCount}`); + console.log(` • Labels that would be removed: ${labelRemovedCount}`); + } else { + console.log(` • Issues closed: ${closedCount}`); + console.log(` • Labels removed: ${labelRemovedCount}`); + } + console.log(` • Issues skipped: ${skippedCount}`); + console.log(` • Total processed: ${targetIssues.length}`); From 500d01aad514fa5f192fe38eff924ed8989446eb Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 12 Sep 2025 13:07:35 -0400 Subject: [PATCH 100/104] fix: Clean up pyproject.toml (#844) --- .pre-commit-config.yaml | 9 +- CONTRIBUTING.md | 11 +-- pyproject.toml | 202 +++++++++++++++++----------------------- 3 files changed, 92 insertions(+), 130 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 37901ae07..e8584a83c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: hooks: - id: hatch-format name: Format code - entry: hatch fmt --formatter + entry: hatch run test-format language: system pass_filenames: false types: [python] @@ -15,13 +15,6 @@ repos: pass_filenames: false types: [python] stages: [pre-commit] - - id: hatch-test-lint - name: Type linting - entry: hatch run test-lint - language: system - pass_filenames: false - types: [ python ] - stages: [ 
pre-commit ] - id: hatch-test name: Unit tests entry: hatch test diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 93970ed64..d107b1fa8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -44,12 +44,7 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as 1. Entering virtual environment using `hatch` (recommended), then launch your IDE in the new shell. ```bash - hatch shell dev - ``` - - Alternatively, install development dependencies in a manually created virtual environment: - ```bash - pip install -e ".[all]" + hatch shell ``` @@ -73,6 +68,10 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as ```bash hatch test ``` + Or run them with coverage: + ```bash + hatch test -c + ``` 6. Run integration tests: ```bash diff --git a/pyproject.toml b/pyproject.toml index 151a80530..cdf4e9063 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,9 +2,10 @@ requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" + [project] name = "strands-agents" -dynamic = ["version"] +dynamic = ["version"] # Version determined by git tags description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" @@ -38,65 +39,25 @@ dependencies = [ "opentelemetry-instrumentation-threading>=0.51b0,<1.00b0", ] -[project.urls] -Homepage = "https://github.com/strands-agents/sdk-python" -"Bug Tracker" = "https://github.com/strands-agents/sdk-python/issues" -Documentation = "https://strandsagents.com" - -[tool.hatch.build.targets.wheel] -packages = ["src/strands"] [project.optional-dependencies] -anthropic = [ - "anthropic>=0.21.0,<1.0.0", -] -dev = [ - "commitizen>=4.4.0,<5.0.0", - "hatch>=1.0.0,<2.0.0", - "moto>=5.1.0,<6.0.0", - "mypy>=1.15.0,<2.0.0", - "pre-commit>=3.2.0,<4.4.0", - "pytest>=8.0.0,<9.0.0", - "pytest-cov>=6.0.0,<7.0.0", - "pytest-asyncio>=1.0.0,<1.2.0", - "pytest-xdist>=3.0.0,<4.0.0", - "ruff>=0.12.0,<0.14.0", +anthropic = ["anthropic>=0.21.0,<1.0.0"] +litellm = ["litellm>=1.75.9,<2.0.0", "openai>=1.68.0,<1.108.0"] +llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] +mistral = ["mistralai>=1.8.2"] +ollama = ["ollama>=0.4.8,<1.0.0"] +openai = ["openai>=1.68.0,<2.0.0"] +writer = ["writer-sdk>=2.2.0,<3.0.0"] +sagemaker = [ + "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", + "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] +otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ "sphinx>=5.0.0,<6.0.0", "sphinx-rtd-theme>=1.0.0,<2.0.0", "sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] -litellm = [ - "litellm>=1.75.9,<2.0.0", - "openai>=1.68.0,<1.108.0", -] -llamaapi = [ - "llama-api-client>=0.1.0,<1.0.0", -] -mistral = [ - "mistralai>=1.8.2", -] -ollama = [ - "ollama>=0.4.8,<1.0.0", -] -openai = [ - "openai>=1.68.0,<2.0.0", -] -otel = [ - "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", -] -writer = [ - "writer-sdk>=2.2.0,<3.0.0" -] - -sagemaker = [ - "boto3>=1.26.0,<2.0.0", - "botocore>=1.29.0,<2.0.0", - "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", - # uses OpenAI as part of the implementation - "openai>=1.68.0,<2.0.0", -] a2a = [ "a2a-sdk>=0.3.0,<0.4.0", @@ -106,22 +67,46 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = [ - "strands-agents[a2a,anthropic,dev,docs,litellm,llamaapi,mistral,ollama,openai,otel]", +all = ["strands-agents[a2a,anthropic,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] + +dev = [ + "commitizen>=4.4.0,<5.0.0", + "hatch>=1.0.0,<2.0.0", + 
"moto>=5.1.0,<6.0.0", + "mypy>=1.15.0,<2.0.0", + "pre-commit>=3.2.0,<4.4.0", + "pytest>=8.0.0,<9.0.0", + "pytest-cov>=7.0.0,<8.0.0", + "pytest-asyncio>=1.0.0,<1.2.0", + "pytest-xdist>=3.0.0,<4.0.0", + "ruff>=0.13.0,<0.14.0", ] +[project.urls] +Homepage = "https://github.com/strands-agents/sdk-python" +"Bug Tracker" = "https://github.com/strands-agents/sdk-python/issues" +Documentation = "https://strandsagents.com" + + +[tool.hatch.build.targets.wheel] +packages = ["src/strands"] + + [tool.hatch.version] -# Tells Hatch to use your version control system (git) to determine the version. -source = "vcs" +source = "vcs" # Use git tags for versioning + [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] +installer = "uv" +features = ["all"] dependencies = [ "mypy>=1.15.0,<2.0.0", - "ruff>=0.11.6,<0.12.0", + "ruff>=0.13.0,<0.14.0", + # Include required pacakge dependencies for mypy "strands-agents @ {root:uri}", ] +# Define static-analysis scripts so we can include mypy as part of the linting check [tool.hatch.envs.hatch-static-analysis.scripts] format-check = [ "ruff format --check" @@ -137,65 +122,54 @@ lint-fix = [ "ruff check --fix" ] + [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] -extra-dependencies = [ - "moto>=5.1.0,<6.0.0", +installer = "uv" +features = ["all"] +extra-args = ["-n", "auto", "-vv"] +dependencies = [ "pytest>=8.0.0,<9.0.0", - "pytest-cov>=6.0.0,<7.0.0", + "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.2.0", "pytest-xdist>=3.0.0,<4.0.0", + "moto>=5.1.0,<6.0.0", ] -extra-args = [ - "-n", - "auto", - "-vv", -] - -[tool.hatch.envs.dev] -dev-mode = true -features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a", "sagemaker"] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] [tool.hatch.envs.hatch-test.scripts] -run = [ - "pytest{env:HATCH_TEST_ARGS:} {args}" -] -run-cov = [ - "pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}" -] - +run = "pytest{env:HATCH_TEST_ARGS:} {args}" # Run with: hatch test +run-cov = "pytest{env:HATCH_TEST_ARGS:} {args} --cov --cov-config=pyproject.toml --cov-report html --cov-report xml {args}" # Run with: hatch test -c cov-combine = [] cov-report = [] -[tool.hatch.envs.default.scripts] -list = [ - "echo 'Scripts commands available for default env:'; hatch env show --json | jq --raw-output '.default.scripts | keys[]'" -] -format = [ - "hatch fmt --formatter", -] -test-format = [ - "hatch fmt --formatter --check", -] -lint = [ - "hatch fmt --linter" -] -test-lint = [ - "hatch fmt --linter --check" -] -test = [ - "hatch test --cover --cov-report html --cov-report xml {args}" -] -test-integ = [ - "hatch test tests_integ {args}" +[tool.hatch.envs.default] +installer = "uv" +dev-mode = true +features = ["all"] +dependencies = [ + "commitizen>=4.4.0,<5.0.0", + "hatch>=1.0.0,<2.0.0", + "pre-commit>=3.2.0,<4.4.0", ] + + +[tool.hatch.envs.default.scripts] +list = "echo 'Scripts commands available for default env:'; hatch env show --json | jq --raw-output '.default.scripts | keys[]'" + +format = "hatch fmt --formatter" +test-format = "hatch fmt --formatter --check" + +lint = "hatch fmt --linter" +test-lint = "hatch fmt --linter --check" + +test = "hatch test {args}" +test-integ = "hatch test tests_integ {args}" + prepare = [ - "hatch fmt --formatter", - 
"hatch fmt --linter", + "hatch run test-format", "hatch run test-lint", "hatch test --all" ] @@ -216,9 +190,6 @@ warn_unreachable = true follow_untyped_imports = true ignore_missing_imports = false -[[tool.mypy.overrides]] -module = "litellm" -ignore_missing_imports = true [tool.ruff] line-length = 120 @@ -226,12 +197,12 @@ include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/* [tool.ruff.lint] select = [ - "B", # flake8-bugbear - "D", # pydocstyle - "E", # pycodestyle - "F", # pyflakes - "G", # logging format - "I", # isort + "B", # flake8-bugbear + "D", # pydocstyle + "E", # pycodestyle + "F", # pyflakes + "G", # logging format + "I", # isort "LOG", # logging ] @@ -241,12 +212,12 @@ select = [ [tool.ruff.lint.pydocstyle] convention = "google" + [tool.pytest.ini_options] -testpaths = [ - "tests" -] +testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" + [tool.coverage.run] branch = true source = ["src"] @@ -263,13 +234,12 @@ directory = "build/coverage/html" [tool.coverage.xml] output = "build/coverage/coverage.xml" + [tool.commitizen] name = "cz_conventional_commits" tag_format = "v$version" bump_message = "chore(release): bump version $current_version -> $new_version" -version_files = [ - "pyproject.toml:version", -] +version_files = ["pyproject.toml:version"] update_changelog_on_bump = true style = [ ["qmark", "fg:#ff9d00 bold"], From 69d3910ccfbf8b45930964f139d5f2a3ffde1a11 Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Fri, 12 Sep 2025 23:18:17 +0200 Subject: [PATCH 101/104] Fixing documentation in decorator.py (#852) The documentation provided for the tool decorator has been updated to work with the version 1.8.0 --- src/strands/tools/decorator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 4923a44ee..99aa7e372 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -36,7 +36,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: } agent = Agent(tools=[my_tool]) - agent.my_tool(param1="hello", param2=123) + agent.tool.my_tool(param1="hello", param2=123) ``` """ From 6ccc8e73636fff929a89793bf470dc511727c480 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 15 Sep 2025 10:23:03 -0400 Subject: [PATCH 102/104] models - openai - use client context (#856) --- src/strands/models/openai.py | 103 ++++++++++++++-------------- tests/strands/models/test_openai.py | 17 ++--- 2 files changed, 58 insertions(+), 62 deletions(-) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index fd75ea175..b80cdddab 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -64,12 +64,10 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: """ validate_config_keys(model_config, self.OpenAIConfig) self.config = dict(model_config) + self.client_args = client_args or {} logger.debug("config=<%s> | initializing", self.config) - client_args = client_args or {} - self.client = openai.AsyncOpenAI(**client_args) - @override def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override] """Update the OpenAI model configuration with the provided arguments. 
@@ -379,58 +377,60 @@ async def stream( logger.debug("formatted request=<%s>", request) logger.debug("invoking model") - response = await self.client.chat.completions.create(**request) - - logger.debug("got response from model") - yield self.format_chunk({"chunk_type": "message_start"}) - yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) - - tool_calls: dict[int, list[Any]] = {} - - async for event in response: - # Defensive: skip events with empty or missing choices - if not getattr(event, "choices", None): - continue - choice = event.choices[0] - - if choice.delta.content: - yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} - ) - - if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: - yield self.format_chunk( - { - "chunk_type": "content_delta", - "data_type": "reasoning_content", - "data": choice.delta.reasoning_content, - } - ) - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) + async with openai.AsyncOpenAI(**self.client_args) as client: + response = await client.chat.completions.create(**request) - if choice.finish_reason: - break + logger.debug("got response from model") + yield self.format_chunk({"chunk_type": "message_start"}) + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + tool_calls: dict[int, list[Any]] = {} - for tool_deltas in tool_calls.values(): - yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) + async for event in response: + # Defensive: skip events with empty or missing choices + if not getattr(event, "choices", None): + continue + choice = event.choices[0] + + if choice.delta.content: + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} + ) + + if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": choice.delta.reasoning_content, + } + ) - for tool_delta in tool_deltas: - yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) + for tool_call in choice.delta.tool_calls or []: + tool_calls.setdefault(tool_call.index, []).append(tool_call) - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + if choice.finish_reason: + break - yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) - # Skip remaining events as we don't have use for anything except the final usage payload - async for event in response: - _ = event + for tool_deltas in tool_calls.values(): + yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) - if event.usage: - yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + for tool_delta in tool_deltas: + yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) + + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) + + # Skip remaining events as we don't have use for anything except the final usage payload + async for event in 
response: + _ = event + + if event.usage: + yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) logger.debug("finished streaming response from model") @@ -449,11 +449,12 @@ async def structured_output( Yields: Model events with the last being the structured output. """ - response: ParsedChatCompletion = await self.client.beta.chat.completions.parse( # type: ignore - model=self.get_config()["model_id"], - messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], - response_format=output_model, - ) + async with openai.AsyncOpenAI(**self.client_args) as client: + response: ParsedChatCompletion = await client.beta.chat.completions.parse( + model=self.get_config()["model_id"], + messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], + response_format=output_model, + ) parsed: T | None = None # Find the first choice with tool_calls diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 64da3cac2..5979ec628 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -8,14 +8,11 @@ @pytest.fixture -def openai_client_cls(): +def openai_client(): with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls: - yield mock_client_cls - - -@pytest.fixture -def openai_client(openai_client_cls): - return openai_client_cls.return_value + mock_client = unittest.mock.AsyncMock() + mock_client_cls.return_value.__aenter__.return_value = mock_client + yield mock_client @pytest.fixture @@ -68,16 +65,14 @@ class TestOutputModel(pydantic.BaseModel): return TestOutputModel -def test__init__(openai_client_cls, model_id): - model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) +def test__init__(model_id): + model = OpenAIModel(model_id=model_id, params={"max_tokens": 1}) tru_config = model.get_config() exp_config = {"model_id": "m1", "params": {"max_tokens": 1}} assert tru_config == exp_config - openai_client_cls.assert_called_once_with(api_key="k1") - def test_update_config(model, model_id): model.update_config(model_id=model_id) From 293f00e02d191704ac6ad02c50366a253528a259 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 10 Sep 2025 18:05:39 -0400 Subject: [PATCH 103/104] feat: add update-docs --- .github/workflows/publish-lambda-layer.yml | 186 +++++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 .github/workflows/publish-lambda-layer.yml diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml new file mode 100644 index 000000000..8384e0fae --- /dev/null +++ b/.github/workflows/publish-lambda-layer.yml @@ -0,0 +1,186 @@ +name: Publish PyPI Package to Lambda Layer + +on: + workflow_dispatch: + inputs: + package_version: + description: 'Package version to download' + required: true + type: string + python_version: + description: 'Python version' + required: false + default: '3.12' + type: choice + options: ['3.10', '3.11', '3.12', '3.13'] + architecture: + description: 'Architecture' + required: false + default: 'x86_64' + type: choice + options: ['x86_64', 'aarch64'] + region: + description: 'AWS region' + required: false + default: 'us-east-1' + type: choice + # Only non opt-in regions included for now + options: ['us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', 'ap-south-1', 'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-southeast-1', 'ap-southeast-2', 'ca-central-1', 'eu-central-1', 'eu-west-1', 'eu-west-2', 'eu-west-3', 
'eu-north-1', 'sa-east-1'] + confirm: + description: 'Type "Create Lambda Layer" to confirm publishing the layer' + required: true + type: string + +env: + IS_FULL_DEPLOY: ${{ !inputs.python_version && !inputs.architecture && !inputs.region }} + +jobs: + publish-layer: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ${{ inputs.python_version && fromJson(format('["{}"', inputs.python_version)) || fromJson('["3.10", "3.11", "3.12", "3.13"]') }} + architecture: ${{ inputs.architecture && fromJson(format('["{}"', inputs.architecture)) || fromJson('["x86_64", "aarch64"]') }} + region: ${{ inputs.region && fromJson(format('["{}"', inputs.region)) || fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') }} + + outputs: + layer-version: ${{ env.LAYER_VERSION }} + permissions: + id-token: write + contents: read + + steps: + - name: Validate confirmation + run: | + CONFIRM="${{ inputs.confirm }}" + if [ "$CONFIRM" != "Create Lambda Layer" ]; then + if [[ "$CONFIRM" =~ ^(x86_64|aarch64|3\.[0-9]+|[a-z]+-[a-z]+-[0-9]+)$ ]]; then + echo "Error: You entered '$CONFIRM' which looks like an architecture, Python version, or region." + echo "Please type exactly 'Create Lambda Layer' to confirm." + else + echo "Confirmation failed. You must type exactly 'Create Lambda Layer' to proceed." + fi + exit 1 + fi + echo "Confirmation validated" + + - name: Checkout current repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Create layer directory structure + run: | + mkdir -p layer/python + + - name: Download and install package + run: | + pip install strands-agents==${{ inputs.package_version }} \ + --python-version ${{ matrix.python-version }} \ + --platform manylinux2014_${{ matrix.architecture }} \ + -t layer/python/ \ + --only-binary=:all: + + - name: Create layer zip + run: | + cd layer + zip -r ../lambda-layer.zip . + + - name: Upload layer to S3 and publish + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + ARCH="${{ matrix.architecture }}" + REGION="${{ matrix.region }}" + LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" + BUCKET_NAME="strands-agents-lambda-layers-$(aws sts get-caller-identity --query Account --output text)-${REGION}" + LAYER_KEY="$LAYER_NAME/v${{ inputs.package_version }}/lambda-layer.zip" + + if ! 
aws s3api head-bucket --bucket "$BUCKET_NAME" 2>/dev/null; then + if [ "$REGION" = "us-east-1" ]; then + aws s3api create-bucket --bucket "$BUCKET_NAME" --region "$REGION" + else + aws s3api create-bucket --bucket "$BUCKET_NAME" --region "$REGION" --create-bucket-configuration LocationConstraint="$REGION" + fi + fi + + aws s3 cp lambda-layer.zip "s3://$BUCKET_NAME/$LAYER_KEY" --region "$REGION" + echo "Uploaded layer to s3://$BUCKET_NAME/$LAYER_KEY" + + DESCRIPTION="PyPI package: strands-agents v${{ inputs.package_version }} (Python $PYTHON_VERSION, $ARCH)" + + LAYER_OUTPUT=$(aws lambda publish-layer-version \ + --layer-name $LAYER_NAME \ + --description "$DESCRIPTION" \ + --content S3Bucket=$BUCKET_NAME,S3Key=$LAYER_KEY \ + --compatible-runtimes python${{ matrix.python-version }} \ + --region "$REGION" \ + --license-info Apache-2.0 \ + --output json) + + LAYER_ARN=$(echo "$LAYER_OUTPUT" | jq -r '.LayerArn') + LAYER_VERSION=$(echo "$LAYER_OUTPUT" | jq -r '.Version') + + echo "Published layer version $LAYER_VERSION with ARN: $LAYER_ARN in region $REGION" + + aws lambda add-layer-version-permission \ + --layer-name $LAYER_NAME \ + --version-number $LAYER_VERSION \ + --statement-id public \ + --action lambda:GetLayerVersion \ + --principal '*' \ + --region "$REGION" + + echo "Successfully published layer version $LAYER_VERSION in region $REGION" + + if [ "${{ env.IS_FULL_DEPLOY }}" = "true" ] && [ "$REGION" = "us-east-1" ] && [ "$PYTHON_VERSION" = "3.10" ] && [ "$ARCH" = "x86_64" ]; then + echo "LAYER_VERSION=$LAYER_VERSION" >> $GITHUB_ENV + fi + + update-docs: + if: ${{ env.IS_FULL_DEPLOY == 'true' }} + needs: publish-layer + runs-on: ubuntu-latest + steps: + - name: Checkout docs repository + uses: actions/checkout@v4 + with: + repository: ${{ github.repository_owner }}/docs + token: ${{ secrets.GITHUB_TOKEN }} + path: docs + + - name: Update lambda layers documentation + run: | + cd docs + LAYER_VERSION="${{ needs.publish-layer.outputs.layer-version }}" + NEW_ROW="| $LAYER_VERSION | [${{ inputs.package_version }}](https://pypi.org/project/strands-agents/${{ inputs.package_version }}) | \`arn:aws:lambda:{REGION}:856699698935:layer:strands-agents-{VERSION}-{ARCHITECTURE}:$LAYER_VERSION\` |" + + sed -i "//a\$NEW_ROW" docs/user-guide/deploy/lambda-layers.md + + - name: Create Pull Request + run: | + cd docs + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + BRANCH="update-lambda-layers-${{ inputs.package_version }}" + git checkout -b "$BRANCH" + git add docs/user-guide/deploy/lambda-layers.md + git commit -m "Update lambda layers with version ${{ inputs.package_version }}" + git push origin "$BRANCH" + + gh pr create \ + --title "Update lambda layers documentation for v${{ inputs.package_version }}" \ + --body "Automated update to add new lambda layer version ${{ inputs.package_version }}" \ + --head "$BRANCH" \ + --base main + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file From c0b52ff35cd1d90b414a214961ffb05ca608d637 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 10 Sep 2025 18:17:21 -0400 Subject: [PATCH 104/104] ci: add workflow for lambda layer publish and yank --- .github/workflows/LAMBDA_LAYERS_SOP.md | 43 +++++ .github/workflows/publish-lambda-layer.yml | 186 +++++++++++---------- .github/workflows/yank-lambda-layer.yml | 81 +++++++++ 3 files changed, 223 insertions(+), 87 deletions(-) create mode 100644 .github/workflows/LAMBDA_LAYERS_SOP.md create mode 100644 
.github/workflows/yank-lambda-layer.yml diff --git a/.github/workflows/LAMBDA_LAYERS_SOP.md b/.github/workflows/LAMBDA_LAYERS_SOP.md new file mode 100644 index 000000000..4ac96a77d --- /dev/null +++ b/.github/workflows/LAMBDA_LAYERS_SOP.md @@ -0,0 +1,43 @@ +# Lambda Layers Standard Operating Procedures (SOP) + +## Overview + +This document defines the standard operating procedures for managing Strands Agents Lambda layers across all AWS regions, Python versions, and architectures. + +**Total: 136 individual Lambda layers** (17 regions × 2 architectures × 4 Python versions). All variants must maintain the same layer version number for each PyPI package version, with only one row per PyPI version appearing in documentation. + +## Deployment Process + +### 1. Initial Deployment +1. Run workflow with ALL options selected (default) +2. Specify PyPI package version +3. Type "Create Lambda Layer {package_version}" to confirm +4. All 136 individual layers deploy in parallel (4 Python × 2 arch × 17 regions) +5. Each layer gets its own unique name: `strands-agents-py{PYTHON_VERSION}-{ARCH}` + +### 2. Version Buffering for New Variants +When adding new variants (new Python version, architecture, or region): + +1. **Determine target layer version**: Check existing variants to find the highest layer version +2. **Buffer deployment**: Deploy new variants multiple times until the layer version matches existing variants +3. **Example**: If existing variants are at layer version 5, deploy the new variant 5 times to reach version 5 + +### 3. Handling Transient Failures +When some regions fail during deployment: + +1. **Identify failed regions**: Check which combinations didn't complete successfully +2. **Targeted redeployment**: Use specific region/arch/Python inputs to redeploy failed combinations +3. **Version alignment**: Continue deploying until all variants reach the same layer version +4. **Verification**: Confirm all combinations have identical layer versions before updating docs (see the verification sketch at the end of this document) + +## Yank Process + +### Yank Procedure +1. Use the `yank-lambda-layer` GitHub Actions workflow +2. Specify the layer version to yank +3. Type "Yank Lambda Layer {layer_version}" to confirm +4. **Full yank**: Run with ALL options selected (default) to yank all 136 variants OR **Partial yank**: Specify Python versions, architectures, and regions for targeted yanking +5. Update documentation +6. **Communication**: Notify users through appropriate channels + +**Note**: Yanking deletes layer versions completely. Existing Lambda functions using the layer continue to work, but new functions cannot use the yanked version.
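+
+## Appendix: Verifying Version Alignment (Sketch)
+
+The snippet below is a minimal sketch of the verification step referenced in sections 2 and 3, assuming the `strands-agents-py{PYTHON_VERSION}-{ARCH}` naming scheme above and an AWS CLI configured with read access to the publisher account; the region list is an illustrative subset and should be expanded to the full matrix. It prints the latest published version of each variant so mismatches can be spotted before documentation is updated:
+
+```bash
+#!/usr/bin/env bash
+# Sketch: report the latest layer version for each region/python/architecture variant.
+set -euo pipefail
+
+REGIONS=(us-east-1 us-east-2 eu-west-1)   # illustrative subset; extend to all 17 regions
+PYTHON_VERSIONS=(3.10 3.11 3.12 3.13)
+ARCHITECTURES=(x86_64 aarch64)
+
+for region in "${REGIONS[@]}"; do
+  for py in "${PYTHON_VERSIONS[@]}"; do
+    for arch in "${ARCHITECTURES[@]}"; do
+      layer="strands-agents-py${py//./_}-${arch}"
+      # list-layer-versions returns versions newest first; take the first entry
+      version=$(aws lambda list-layer-versions \
+        --layer-name "$layer" \
+        --region "$region" \
+        --query 'LayerVersions[0].Version' \
+        --output text)
+      echo "$region $layer -> version $version"
+    done
+  done
+done
+```
+
+Any variant reporting a different (or missing) version needs additional targeted deployments until all versions align.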
\ No newline at end of file diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml index 8384e0fae..b4bceca83 100644 --- a/.github/workflows/publish-lambda-layer.yml +++ b/.github/workflows/publish-lambda-layer.yml @@ -9,64 +9,90 @@ on: type: string python_version: description: 'Python version' - required: false - default: '3.12' + required: true + default: 'ALL' type: choice - options: ['3.10', '3.11', '3.12', '3.13'] + options: ['ALL', '3.10', '3.11', '3.12', '3.13'] architecture: description: 'Architecture' - required: false - default: 'x86_64' + required: true + default: 'ALL' type: choice - options: ['x86_64', 'aarch64'] + options: ['ALL', 'x86_64', 'aarch64'] region: description: 'AWS region' - required: false - default: 'us-east-1' + required: true + default: 'ALL' type: choice # Only non opt-in regions included for now - options: ['us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', 'ap-south-1', 'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-southeast-1', 'ap-southeast-2', 'ca-central-1', 'eu-central-1', 'eu-west-1', 'eu-west-2', 'eu-west-3', 'eu-north-1', 'sa-east-1'] + options: ['ALL', 'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', 'ap-south-1', 'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-southeast-1', 'ap-southeast-2', 'ca-central-1', 'eu-central-1', 'eu-west-1', 'eu-west-2', 'eu-west-3', 'eu-north-1', 'sa-east-1'] confirm: - description: 'Type "Create Lambda Layer" to confirm publishing the layer' + description: 'Type "Create Lambda Layer {PyPI version}" to confirm publishing the layer' required: true type: string env: - IS_FULL_DEPLOY: ${{ !inputs.python_version && !inputs.architecture && !inputs.region }} + BUCKET_NAME: strands-agents-lambda-layer jobs: - publish-layer: + validate: + runs-on: ubuntu-latest + steps: + - name: Validate confirmation + run: | + CONFIRM="${{ inputs.confirm }}" + EXPECTED="Create Lambda Layer ${{ inputs.package_version }}" + if [ "$CONFIRM" != "$EXPECTED" ]; then + echo "Confirmation failed. You must type exactly '$EXPECTED' to proceed." 
+ exit 1 + fi + echo "Confirmation validated" + + create-buckets: + needs: validate runs-on: ubuntu-latest strategy: matrix: - python-version: ${{ inputs.python_version && fromJson(format('["{}"', inputs.python_version)) || fromJson('["3.10", "3.11", "3.12", "3.13"]') }} - architecture: ${{ inputs.architecture && fromJson(format('["{}"', inputs.architecture)) || fromJson('["x86_64", "aarch64"]') }} - region: ${{ inputs.region && fromJson(format('["{}"', inputs.region)) || fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') }} - - outputs: - layer-version: ${{ env.LAYER_VERSION }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} permissions: id-token: write - contents: read - steps: - - name: Validate confirmation + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Create S3 bucket run: | - CONFIRM="${{ inputs.confirm }}" - if [ "$CONFIRM" != "Create Lambda Layer" ]; then - if [[ "$CONFIRM" =~ ^(x86_64|aarch64|3\.[0-9]+|[a-z]+-[a-z]+-[0-9]+)$ ]]; then - echo "Error: You entered '$CONFIRM' which looks like an architecture, Python version, or region." - echo "Please type exactly 'Create Lambda Layer' to confirm." + REGION="${{ matrix.region }}" + ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) + REGIONAL_BUCKET="${{ env.BUCKET_NAME }}-${ACCOUNT_ID}-${REGION}" + + if ! aws s3api head-bucket --bucket "$REGIONAL_BUCKET" 2>/dev/null; then + if [ "$REGION" = "us-east-1" ]; then + aws s3api create-bucket --bucket "$REGIONAL_BUCKET" --region "$REGION" 2>/dev/null || echo "Bucket $REGIONAL_BUCKET already exists" else - echo "Confirmation failed. You must type exactly 'Create Lambda Layer' to proceed." 
+ aws s3api create-bucket --bucket "$REGIONAL_BUCKET" --region "$REGION" --create-bucket-configuration LocationConstraint="$REGION" 2>/dev/null || echo "Bucket $REGIONAL_BUCKET already exists" fi - exit 1 + echo "S3 bucket ready: $REGIONAL_BUCKET" + else + echo "S3 bucket already exists: $REGIONAL_BUCKET" fi - echo "Confirmation validated" - - name: Checkout current repository - uses: actions/checkout@v4 + package-and-upload: + needs: create-buckets + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} + architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + + permissions: + id-token: write + steps: - name: Set up Python uses: actions/setup-python@v4 with: @@ -94,34 +120,64 @@ jobs: run: | cd layer zip -r ../lambda-layer.zip . - - - name: Upload layer to S3 and publish + + - name: Upload to S3 run: | PYTHON_VERSION="${{ matrix.python-version }}" ARCH="${{ matrix.architecture }}" REGION="${{ matrix.region }}" LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" - BUCKET_NAME="strands-agents-lambda-layers-$(aws sts get-caller-identity --query Account --output text)-${REGION}" + ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) + BUCKET_NAME="${{ env.BUCKET_NAME }}-${ACCOUNT_ID}-${REGION}" LAYER_KEY="$LAYER_NAME/v${{ inputs.package_version }}/lambda-layer.zip" - if ! 
aws s3api head-bucket --bucket "$BUCKET_NAME" 2>/dev/null; then - if [ "$REGION" = "us-east-1" ]; then - aws s3api create-bucket --bucket "$BUCKET_NAME" --region "$REGION" - else - aws s3api create-bucket --bucket "$BUCKET_NAME" --region "$REGION" --create-bucket-configuration LocationConstraint="$REGION" - fi - fi - aws s3 cp lambda-layer.zip "s3://$BUCKET_NAME/$LAYER_KEY" --region "$REGION" echo "Uploaded layer to s3://$BUCKET_NAME/$LAYER_KEY" + publish-layer: + needs: package-and-upload + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} + architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + + permissions: + id-token: write + + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Publish layer + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + ARCH="${{ matrix.architecture }}" + REGION="${{ matrix.region }}" + LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" + ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) + REGION_BUCKET="${{ env.BUCKET_NAME }}-${ACCOUNT_ID}-${REGION}" + LAYER_KEY="$LAYER_NAME/v${{ inputs.package_version }}/lambda-layer.zip" + DESCRIPTION="PyPI package: strands-agents v${{ inputs.package_version }} (Python $PYTHON_VERSION, $ARCH)" + # Set compatible architecture based on matrix architecture + if [ "$ARCH" = "x86_64" ]; then + COMPATIBLE_ARCH="x86_64" + else + COMPATIBLE_ARCH="arm64" + fi + LAYER_OUTPUT=$(aws lambda publish-layer-version \ --layer-name $LAYER_NAME \ --description "$DESCRIPTION" \ - --content S3Bucket=$BUCKET_NAME,S3Key=$LAYER_KEY \ + --content S3Bucket=$REGION_BUCKET,S3Key=$LAYER_KEY \ --compatible-runtimes python${{ matrix.python-version }} \ + --compatible-architectures $COMPATIBLE_ARCH \ --region "$REGION" \ --license-info Apache-2.0 \ --output json) @@ -140,47 +196,3 @@ jobs: --region "$REGION" echo "Successfully published layer version $LAYER_VERSION in region $REGION" - - if [ "${{ env.IS_FULL_DEPLOY }}" = "true" ] && [ "$REGION" = "us-east-1" ] && [ "$PYTHON_VERSION" = "3.10" ] && [ "$ARCH" = "x86_64" ]; then - echo "LAYER_VERSION=$LAYER_VERSION" >> $GITHUB_ENV - fi - - update-docs: - if: ${{ env.IS_FULL_DEPLOY == 'true' }} - needs: publish-layer - runs-on: ubuntu-latest - steps: - - name: Checkout docs repository - uses: actions/checkout@v4 - with: - repository: ${{ github.repository_owner }}/docs - token: ${{ secrets.GITHUB_TOKEN }} - path: docs - - - name: Update lambda layers documentation - run: | - cd docs - LAYER_VERSION="${{ needs.publish-layer.outputs.layer-version }}" - NEW_ROW="| $LAYER_VERSION | [${{ inputs.package_version }}](https://pypi.org/project/strands-agents/${{ inputs.package_version }}) | \`arn:aws:lambda:{REGION}:856699698935:layer:strands-agents-{VERSION}-{ARCHITECTURE}:$LAYER_VERSION\` |" - - sed -i "//a\$NEW_ROW" 
docs/user-guide/deploy/lambda-layers.md - - - name: Create Pull Request - run: | - cd docs - git config user.name "github-actions[bot]" - git config user.email "github-actions[bot]@users.noreply.github.com" - - BRANCH="update-lambda-layers-${{ inputs.package_version }}" - git checkout -b "$BRANCH" - git add docs/user-guide/deploy/lambda-layers.md - git commit -m "Update lambda layers with version ${{ inputs.package_version }}" - git push origin "$BRANCH" - - gh pr create \ - --title "Update lambda layers documentation for v${{ inputs.package_version }}" \ - --body "Automated update to add new lambda layer version ${{ inputs.package_version }}" \ - --head "$BRANCH" \ - --base main - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/yank-lambda-layer.yml b/.github/workflows/yank-lambda-layer.yml new file mode 100644 index 000000000..27927a862 --- /dev/null +++ b/.github/workflows/yank-lambda-layer.yml @@ -0,0 +1,81 @@ +name: Yank Lambda Layer + +on: + workflow_dispatch: + inputs: + layer_version: + description: 'Layer version to yank' + required: true + type: string + python_version: + description: 'Python version' + required: true + default: 'ALL' + type: choice + options: ['ALL', '3.10', '3.11', '3.12', '3.13'] + architecture: + description: 'Architecture' + required: true + default: 'ALL' + type: choice + options: ['ALL', 'x86_64', 'aarch64'] + region: + description: 'AWS region' + required: true + default: 'ALL' + type: choice + # Only non opt-in regions included for now + options: ['ALL', 'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', 'ap-south-1', 'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-southeast-1', 'ap-southeast-2', 'ca-central-1', 'eu-central-1', 'eu-west-1', 'eu-west-2', 'eu-west-3', 'eu-north-1', 'sa-east-1'] + confirm: + description: 'Type "Yank Lambda Layer {layer version}" to confirm yanking the layer' + required: true + type: string + +jobs: + yank-layer: + runs-on: ubuntu-latest + continue-on-error: true + strategy: + fail-fast: false + matrix: + python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} + architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + + permissions: + id-token: write + + steps: + - name: Validate confirmation + run: | + CONFIRM="${{ inputs.confirm }}" + EXPECTED="Yank Lambda Layer ${{ inputs.layer_version }}" + if [ "$CONFIRM" != "$EXPECTED" ]; then + echo "Confirmation failed. You must type exactly '$EXPECTED' to proceed." 
+ exit 1 + fi + echo "Confirmation validated" + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Yank layer + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + ARCH="${{ matrix.architecture }}" + REGION="${{ matrix.region }}" + LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" + LAYER_VERSION="${{ inputs.layer_version }}" + + echo "Attempting to yank layer $LAYER_NAME version $LAYER_VERSION in region $REGION" + + # Delete the layer version completely + aws lambda delete-layer-version \ + --layer-name $LAYER_NAME \ + --version-number $LAYER_VERSION \ + --region "$REGION" + + echo "Completed yank attempt for layer $LAYER_NAME version $LAYER_VERSION in region $REGION" \ No newline at end of file