From dd76000f5455a5c22dd9ec6fc406f4a08871777f Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Mon, 21 Jul 2025 11:17:52 -0400
Subject: [PATCH 001/104] Use strands logo that looks good in dark & light mode
(#505)
Similar to strands-agents/sdk-python/pull/475, but using a dedicated GitHub icon.
The GitHub icon is the light logo, copied and renamed so that it is dedicated to GitHub.
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index c31048770..58c647f8d 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
From 24ccb00159c4319cfb5fd3bea4caa5b50c846539 Mon Sep 17 00:00:00 2001
From: Jeremiah
Date: Tue, 22 Jul 2025 11:48:26 -0400
Subject: [PATCH 002/104] deps(a2a): address interface changes and bump min
version (#515)
Co-authored-by: jer
---
pyproject.toml | 4 ++--
src/strands/multiagent/a2a/server.py | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 974ff9d94..765e815ef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -90,7 +90,7 @@ writer = [
]
a2a = [
- "a2a-sdk[sql]>=0.2.11,<1.0.0",
+ "a2a-sdk[sql]>=0.2.16,<1.0.0",
"uvicorn>=0.34.2,<1.0.0",
"httpx>=0.28.1,<1.0.0",
"fastapi>=0.115.12,<1.0.0",
@@ -136,7 +136,7 @@ all = [
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
# a2a
- "a2a-sdk[sql]>=0.2.11,<1.0.0",
+ "a2a-sdk[sql]>=0.2.16,<1.0.0",
"uvicorn>=0.34.2,<1.0.0",
"httpx>=0.28.1,<1.0.0",
"fastapi>=0.115.12,<1.0.0",
diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py
index 568252597..de891499d 100644
--- a/src/strands/multiagent/a2a/server.py
+++ b/src/strands/multiagent/a2a/server.py
@@ -83,8 +83,8 @@ def public_agent_card(self) -> AgentCard:
url=self.http_url,
version=self.version,
skills=self.agent_skills,
- defaultInputModes=["text"],
- defaultOutputModes=["text"],
+ default_input_modes=["text"],
+ default_output_modes=["text"],
capabilities=self.capabilities,
)
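
For reference, a minimal sketch of the renamed fields, assuming the `a2a.types` pydantic models from a2a-sdk >=0.2.16 (all values hypothetical):

```python
from a2a.types import AgentCapabilities, AgentCard

card = AgentCard(
    name="example-agent",
    description="An example agent",
    url="http://localhost:9000/",
    version="0.0.1",
    skills=[],
    capabilities=AgentCapabilities(),
    default_input_modes=["text"],   # camelCase defaultInputModes before 0.2.16
    default_output_modes=["text"],  # camelCase defaultOutputModes before 0.2.16
)
```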
From 69053420de6695ffc3921481eba04935735f55e3 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Tue, 22 Jul 2025 12:45:39 -0400
Subject: [PATCH 003/104] ci: expose STRANDS_TEST_API_KEYS_SECRET_NAME to
integration tests (#513)
---
.github/workflows/integration-test.yml | 1 +
1 file changed, 1 insertion(+)
diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index a1d86364a..c347e3805 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -67,6 +67,7 @@ jobs:
env:
AWS_REGION: us-east-1
AWS_REGION_NAME: us-east-1 # Needed for LiteLLM
+ STRANDS_TEST_API_KEYS_SECRET_NAME: ${{ secrets.STRANDS_TEST_API_KEYS_SECRET_NAME }}
id: tests
run: |
hatch test tests_integ
From 5a7076bfbd01c415fee1c2ec2316c005da9d973a Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Tue, 22 Jul 2025 14:22:17 -0400
Subject: [PATCH 004/104] Don't re-run workflows on un/approvals (#516)
These were necessary when we had conditional running, but we have since switched to requiring approval of all workflows for non-maintainers, so we no longer need these.
Co-authored-by: Mackenzie Zastrow
---
.github/workflows/pr-and-push.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml
index 2b2d026f4..b558943dd 100644
--- a/.github/workflows/pr-and-push.yml
+++ b/.github/workflows/pr-and-push.yml
@@ -3,7 +3,7 @@ name: Pull Request and Push Action
on:
pull_request: # Safer than pull_request_target for untrusted code
branches: [ main ]
- types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed]
+ types: [opened, synchronize, reopened, ready_for_review]
push:
branches: [ main ] # Also run on direct pushes to main
concurrency:
From 9aba0189abf43136a9c3eb477ee5257f735730c9 Mon Sep 17 00:00:00 2001
From: Didier Durand
Date: Tue, 22 Jul 2025 21:49:29 +0200
Subject: [PATCH 005/104] Fixing some typos in various texts (#487)
---
.../conversation_manager/conversation_manager.py | 2 +-
src/strands/multiagent/a2a/executor.py | 2 +-
src/strands/session/repository_session_manager.py | 14 +++++++-------
src/strands/types/session.py | 4 ++--
4 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py
index 8756a1022..2c1ee7847 100644
--- a/src/strands/agent/conversation_manager/conversation_manager.py
+++ b/src/strands/agent/conversation_manager/conversation_manager.py
@@ -36,7 +36,7 @@ def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]
Args:
state: Previous state of the conversation manager
Returns:
- Optional list of messages to prepend to the agents messages. By defualt returns None.
+ Optional list of messages to prepend to the agents messages. By default returns None.
"""
if state.get("__name__") != self.__class__.__name__:
raise ValueError("Invalid conversation manager state.")
diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py
index 00eb4764f..d65c64aff 100644
--- a/src/strands/multiagent/a2a/executor.py
+++ b/src/strands/multiagent/a2a/executor.py
@@ -4,7 +4,7 @@
to be used as an executor in the A2A protocol. It handles the execution of agent
requests and the conversion of Strands Agent streamed responses to A2A events.
-The A2A AgentExecutor ensures clients recieve responses for synchronous and
+The A2A AgentExecutor ensures clients receive responses for synchronous and
streamed requests to the A2AServer.
"""
diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py
index 487335ac9..18a6ac474 100644
--- a/src/strands/session/repository_session_manager.py
+++ b/src/strands/session/repository_session_manager.py
@@ -32,7 +32,7 @@ def __init__(self, session_id: str, session_repository: SessionRepository, **kwa
Args:
session_id: ID to use for the session. A new session with this id will be created if it does
- not exist in the reposiory yet
+ not exist in the repository yet
session_repository: Underlying session repository to use to store the sessions state.
**kwargs: Additional keyword arguments for future extensibility.
@@ -133,15 +133,15 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
agent.state = AgentState(session_agent.state)
# Restore the conversation manager to its previous state, and get the optional prepend messages
- prepend_messsages = agent.conversation_manager.restore_from_session(
+ prepend_messages = agent.conversation_manager.restore_from_session(
session_agent.conversation_manager_state
)
- if prepend_messsages is None:
- prepend_messsages = []
+ if prepend_messages is None:
+ prepend_messages = []
# List the messages currently in the session, using an offset of the messages previously removed
- # by the converstaion manager.
+ # by the conversation manager.
session_messages = self.session_repository.list_messages(
session_id=self.session_id,
agent_id=agent.agent_id,
@@ -150,5 +150,5 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
if len(session_messages) > 0:
self._latest_agent_message[agent.agent_id] = session_messages[-1]
- # Resore the agents messages array including the optional prepend messages
- agent.messages = prepend_messsages + [session_message.to_message() for session_message in session_messages]
+ # Restore the agents messages array including the optional prepend messages
+ agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages]
diff --git a/src/strands/types/session.py b/src/strands/types/session.py
index 259ab1171..e51816f74 100644
--- a/src/strands/types/session.py
+++ b/src/strands/types/session.py
@@ -125,7 +125,7 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent":
@classmethod
def from_dict(cls, env: dict[str, Any]) -> "SessionAgent":
- """Initialize a SessionAgent from a dictionary, ignoring keys that are not calss parameters."""
+ """Initialize a SessionAgent from a dictionary, ignoring keys that are not class parameters."""
return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters})
def to_dict(self) -> dict[str, Any]:
@@ -144,7 +144,7 @@ class Session:
@classmethod
def from_dict(cls, env: dict[str, Any]) -> "Session":
- """Initialize a Session from a dictionary, ignoring keys that are not calss parameters."""
+ """Initialize a Session from a dictionary, ignoring keys that are not class parameters."""
return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters})
def to_dict(self) -> dict[str, Any]:
From 040ba21cdfeb5dfbcdbb6e76ec227356a4429329 Mon Sep 17 00:00:00 2001
From: ./c²
Date: Tue, 22 Jul 2025 15:52:35 -0400
Subject: [PATCH 006/104] docs(readme): add hot reloading documentation for
load_tools_from_directory (#517)
- Add new section showcasing Agent(load_tools_from_directory=True) functionality
- Document automatic tool loading and reloading from ./tools/ directory
- Include practical code example for developers
- Improve discoverability of this development feature
---
README.md | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/README.md b/README.md
index 58c647f8d..62ed54d47 100644
--- a/README.md
+++ b/README.md
@@ -91,6 +91,17 @@ agent = Agent(tools=[word_count])
response = agent("How many words are in this sentence?")
```
+**Hot Reloading from Directory:**
+Enable automatic tool loading and reloading from the `./tools/` directory:
+
+```python
+from strands import Agent
+
+# Agent will watch ./tools/ directory for changes
+agent = Agent(load_tools_from_directory=True)
+response = agent("Use any tools you find in the tools directory")
+```
+
### MCP Support
Seamlessly integrate Model Context Protocol (MCP) servers:
From 022ec556d7eed2de935deb8293e86f8263056af5 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Tue, 22 Jul 2025 16:19:15 -0400
Subject: [PATCH 007/104] ci: enable integ tests for anthropic, cohere,
mistral, openai, writer (#510)
---
tests_integ/conftest.py | 52 +++++++++++++++++++
tests_integ/models/providers.py | 4 +-
.../{conformance.py => test_conformance.py} | 4 +-
tests_integ/models/test_model_anthropic.py | 13 +++--
tests_integ/models/test_model_cohere.py | 2 +-
5 files changed, 67 insertions(+), 8 deletions(-)
rename tests_integ/models/{conformance.py => test_conformance.py} (81%)
diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py
index f83f0e299..61c2bf9a1 100644
--- a/tests_integ/conftest.py
+++ b/tests_integ/conftest.py
@@ -1,5 +1,17 @@
+import json
+import logging
+import os
+
+import boto3
import pytest
+logger = logging.getLogger(__name__)
+
+
+def pytest_sessionstart(session):
+ _load_api_keys_from_secrets_manager()
+
+
## Data
@@ -28,3 +40,43 @@ async def alist(items):
return [item async for item in items]
return alist
+
+
+## Models
+
+
+def _load_api_keys_from_secrets_manager():
+ """Load API keys as environment variables from AWS Secrets Manager."""
+ session = boto3.session.Session()
+ client = session.client(service_name="secretsmanager")
+ if "STRANDS_TEST_API_KEYS_SECRET_NAME" in os.environ:
+ try:
+ secret_name = os.getenv("STRANDS_TEST_API_KEYS_SECRET_NAME")
+ response = client.get_secret_value(SecretId=secret_name)
+
+ if "SecretString" in response:
+ secret = json.loads(response["SecretString"])
+ for key, value in secret.items():
+ os.environ[f"{key.upper()}_API_KEY"] = str(value)
+
+ except Exception as e:
+ logger.warning("Error retrieving secret: %s", e)
+
+ """
+ Validate that required environment variables are set when running in GitHub Actions.
+ This prevents tests from being unintentionally skipped due to missing credentials.
+ """
+ if os.environ.get("GITHUB_ACTIONS") != "true":
+ logger.warning("Tests running outside GitHub Actions, skipping required provider validation")
+ return
+
+ required_providers = {
+ "ANTHROPIC_API_KEY",
+ "COHERE_API_KEY",
+ "MISTRAL_API_KEY",
+ "OPENAI_API_KEY",
+ "WRITER_API_KEY",
+ }
+ for provider in required_providers:
+ if provider not in os.environ or not os.environ[provider]:
+ raise ValueError(f"Missing required environment variables for {provider}")
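
The loader above maps each key of the secret's JSON payload to a `<KEY>_API_KEY` environment variable. A minimal sketch of that mapping, using a hypothetical SecretString value:

```python
import json
import os

# Hypothetical SecretString payload from AWS Secrets Manager
secret = json.loads('{"anthropic": "sk-ant-example", "openai": "sk-oai-example"}')

# Same transformation as the conftest above: provider name -> <NAME>_API_KEY
for key, value in secret.items():
    os.environ[f"{key.upper()}_API_KEY"] = str(value)

assert os.environ["ANTHROPIC_API_KEY"] == "sk-ant-example"
assert os.environ["OPENAI_API_KEY"] == "sk-oai-example"
```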
diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py
index 543f58480..d2ac148d3 100644
--- a/tests_integ/models/providers.py
+++ b/tests_integ/models/providers.py
@@ -72,11 +72,11 @@ def __init__(self):
bedrock = ProviderInfo(id="bedrock", factory=lambda: BedrockModel())
cohere = ProviderInfo(
id="cohere",
- environment_variable="CO_API_KEY",
+ environment_variable="COHERE_API_KEY",
factory=lambda: OpenAIModel(
client_args={
"base_url": "https://api.cohere.com/compatibility/v1",
- "api_key": os.getenv("CO_API_KEY"),
+ "api_key": os.getenv("COHERE_API_KEY"),
},
model_id="command-a-03-2025",
params={"stream_options": None},
diff --git a/tests_integ/models/conformance.py b/tests_integ/models/test_conformance.py
similarity index 81%
rename from tests_integ/models/conformance.py
rename to tests_integ/models/test_conformance.py
index 262e41e42..d9875bc07 100644
--- a/tests_integ/models/conformance.py
+++ b/tests_integ/models/test_conformance.py
@@ -1,6 +1,6 @@
import pytest
-from strands.types.models import Model
+from strands.models import Model
from tests_integ.models.providers import ProviderInfo, all_providers
@@ -9,7 +9,7 @@ def get_models():
pytest.param(
provider_info,
id=provider_info.id, # Adds the provider name to the test name
- marks=[provider_info.mark], # ignores tests that don't have the requirements
+ marks=provider_info.mark, # ignores tests that don't have the requirements
)
for provider_info in all_providers
]
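
Note that `pytest.param` accepts either a single mark or a list of marks, so both spellings are valid; the change simply passes the mark through unwrapped. A small sketch (the mark itself is hypothetical):

```python
import pytest

requires_api_key = pytest.mark.skipif(True, reason="API key not set")  # hypothetical mark

params = [
    pytest.param("anthropic", id="anthropic", marks=requires_api_key),  # single mark
    pytest.param("openai", id="openai", marks=[requires_api_key]),      # list of marks
]
```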
diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py
index 2ee5e7f23..62a95d06d 100644
--- a/tests_integ/models/test_model_anthropic.py
+++ b/tests_integ/models/test_model_anthropic.py
@@ -6,10 +6,17 @@
import strands
from strands import Agent
from strands.models.anthropic import AnthropicModel
-from tests_integ.models import providers
-# these tests only run if we have the anthropic api key
-pytestmark = providers.anthropic.mark
+"""
+These tests only run if we have the anthropic api key
+
+Because of infrequent burst usage, Anthropic tests are unreliable, failing with 529s.
+{'type': 'error', 'error': {'details': None, 'type': 'overloaded_error', 'message': 'Overloaded'}}
+https://docs.anthropic.com/en/api/errors#http-errors
+"""
+pytestmark = pytest.skip(
+ "Because of infrequent burst usage, Anthropic tests are unreliable, failing with 529s", allow_module_level=True
+)
@pytest.fixture
diff --git a/tests_integ/models/test_model_cohere.py b/tests_integ/models/test_model_cohere.py
index 996b0f326..33fb1a8c6 100644
--- a/tests_integ/models/test_model_cohere.py
+++ b/tests_integ/models/test_model_cohere.py
@@ -16,7 +16,7 @@ def model():
return OpenAIModel(
client_args={
"base_url": "https://api.cohere.com/compatibility/v1",
- "api_key": os.getenv("CO_API_KEY"),
+ "api_key": os.getenv("COHERE_API_KEY"),
},
model_id="command-a-03-2025",
params={"stream_options": None},
From e597e07f06665292c4207270f41eb37cc45fd645 Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Wed, 23 Jul 2025 11:26:30 -0400
Subject: [PATCH 008/104] Automatically flatten nested tool collections (#508)
Fixes issue #50
Customers naturally want to pass nested collections of tools; the issue above has gathered enough data points to prove that.
---
src/strands/tools/registry.py | 11 +++++++++--
tests/strands/agent/test_agent.py | 19 +++++++++++++++++++
tests/strands/tools/test_registry.py | 27 +++++++++++++++++++++++++++
3 files changed, 55 insertions(+), 2 deletions(-)
diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py
index 9d835d28e..fd395ae77 100644
--- a/src/strands/tools/registry.py
+++ b/src/strands/tools/registry.py
@@ -11,7 +11,7 @@
from importlib import import_module, util
from os.path import expanduser
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional
from typing_extensions import TypedDict, cast
@@ -54,7 +54,7 @@ def process_tools(self, tools: List[Any]) -> List[str]:
"""
tool_names = []
- for tool in tools:
+ def add_tool(tool: Any) -> None:
# Case 1: String file path
if isinstance(tool, str):
# Extract tool name from path
@@ -97,9 +97,16 @@ def process_tools(self, tools: List[Any]) -> List[str]:
elif isinstance(tool, AgentTool):
self.register_tool(tool)
tool_names.append(tool.tool_name)
+ # Case 6: Nested iterable (list, tuple, etc.) - add each sub-tool
+ elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)):
+ for t in tool:
+ add_tool(t)
else:
logger.warning("tool=<%s> | unrecognized tool specification", tool)
+ for a_tool in tools:
+ add_tool(a_tool)
+
return tool_names
def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None:
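
A usage sketch of the new flattening behavior (tool functions hypothetical); nested lists, tuples, and sets now register each contained tool:

```python
from strands import Agent, tool

@tool
def word_count(text: str) -> int:
    """Count the words in a text."""
    return len(text.split())

@tool
def shout(text: str) -> str:
    """Uppercase a text."""
    return text.upper()

# Nested collections are flattened recursively instead of being rejected
agent = Agent(tools=[word_count, [shout]])
assert sorted(agent.tool_names) == ["shout", "word_count"]
```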
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index d6471a09a..4e310dace 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -231,6 +231,25 @@ def test_agent__init__with_string_model_id():
assert agent.model.config["model_id"] == "nonsense"
+def test_agent__init__nested_tools_flattening(tool_decorated, tool_module, tool_imported, tool_registry):
+ _ = tool_registry
+ # Nested structure: [tool_decorated, [tool_module, [tool_imported]]]
+ agent = Agent(tools=[tool_decorated, [tool_module, [tool_imported]]])
+ tru_tool_names = sorted(agent.tool_names)
+ exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"]
+ assert tru_tool_names == exp_tool_names
+
+
+def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_imported, tool_registry):
+ _ = tool_registry
+ # Deeply nested structure
+ nested_tools = [[[[tool_decorated]], [[tool_module]], tool_imported]]
+ agent = Agent(tools=nested_tools)
+ tru_tool_names = sorted(agent.tool_names)
+ exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"]
+ assert tru_tool_names == exp_tool_names
+
+
def test_agent__call__(
mock_model,
system_prompt,
diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py
index ebcba3fb1..66494c987 100644
--- a/tests/strands/tools/test_registry.py
+++ b/tests/strands/tools/test_registry.py
@@ -93,3 +93,30 @@ def tool_function_4(d):
assert len(tools) == 2
assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools)
+
+
+def test_process_tools_flattens_lists_and_tuples_and_sets():
+ def function() -> str:
+ return "done"
+
+ tool_a = tool(name="tool_a")(function)
+ tool_b = tool(name="tool_b")(function)
+ tool_c = tool(name="tool_c")(function)
+ tool_d = tool(name="tool_d")(function)
+ tool_e = tool(name="tool_e")(function)
+ tool_f = tool(name="tool_f")(function)
+
+ registry = ToolRegistry()
+
+ all_tools = [tool_a, (tool_b, tool_c), [{tool_d, tool_e}, [tool_f]]]
+
+ tru_tool_names = sorted(registry.process_tools(all_tools))
+ exp_tool_names = [
+ "tool_a",
+ "tool_b",
+ "tool_c",
+ "tool_d",
+ "tool_e",
+ "tool_f",
+ ]
+ assert tru_tool_names == exp_tool_names
From 4f4e5efd6730fd05ae4382d5ab1715e7b363be6c Mon Sep 17 00:00:00 2001
From: Jeremiah
Date: Wed, 23 Jul 2025 13:44:47 -0400
Subject: [PATCH 009/104] feat(a2a): support mounts for containerized
deployments (#524)
* feat(a2a): support mounts for containerized deployments
* feat(a2a): escape hatch for load balancers which strip paths
* feat(a2a): formatting
---------
Co-authored-by: jer
---
src/strands/multiagent/a2a/server.py | 75 +++-
.../session/repository_session_manager.py | 4 +-
tests/strands/multiagent/a2a/test_server.py | 343 ++++++++++++++++++
3 files changed, 412 insertions(+), 10 deletions(-)
diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py
index de891499d..fa7b6b887 100644
--- a/src/strands/multiagent/a2a/server.py
+++ b/src/strands/multiagent/a2a/server.py
@@ -6,6 +6,7 @@
import logging
from typing import Any, Literal
+from urllib.parse import urlparse
import uvicorn
from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication
@@ -31,6 +32,8 @@ def __init__(
# AgentCard
host: str = "0.0.0.0",
port: int = 9000,
+ http_url: str | None = None,
+ serve_at_root: bool = False,
version: str = "0.0.1",
skills: list[AgentSkill] | None = None,
):
@@ -40,13 +43,34 @@ def __init__(
agent: The Strands Agent to wrap with A2A compatibility.
host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0".
port: The port to bind the A2A server to. Defaults to 9000.
+ http_url: The public HTTP URL where this agent will be accessible. If provided,
+ this overrides the generated URL from host/port and enables automatic
+ path-based mounting for load balancer scenarios.
+ Example: "http://my-alb.amazonaws.com/agent1"
+ serve_at_root: If True, forces the server to serve at root path regardless of
+ http_url path component. Use this when your load balancer strips path prefixes.
+ Defaults to False.
version: The version of the agent. Defaults to "0.0.1".
skills: The list of capabilities or functions the agent can perform.
"""
self.host = host
self.port = port
- self.http_url = f"http://{self.host}:{self.port}/"
self.version = version
+
+ if http_url:
+ # Parse the provided URL to extract components for mounting
+ self.public_base_url, self.mount_path = self._parse_public_url(http_url)
+ self.http_url = http_url.rstrip("/") + "/"
+
+ # Override mount path if serve_at_root is requested
+ if serve_at_root:
+ self.mount_path = ""
+ else:
+ # Fall back to constructing the URL from host and port
+ self.public_base_url = f"http://{host}:{port}"
+ self.http_url = f"{self.public_base_url}/"
+ self.mount_path = ""
+
self.strands_agent = agent
self.name = self.strands_agent.name
self.description = self.strands_agent.description
@@ -58,6 +82,25 @@ def __init__(
self._agent_skills = skills
logger.info("Strands' integration with A2A is experimental. Be aware of frequent breaking changes.")
+ def _parse_public_url(self, url: str) -> tuple[str, str]:
+ """Parse the public URL into base URL and mount path components.
+
+ Args:
+ url: The full public URL (e.g., "http://my-alb.amazonaws.com/agent1")
+
+ Returns:
+ tuple: (base_url, mount_path) where base_url is the scheme+netloc
+ and mount_path is the path component
+
+ Example:
+ _parse_public_url("http://my-alb.amazonaws.com/agent1")
+ Returns: ("http://my-alb.amazonaws.com", "/agent1")
+ """
+ parsed = urlparse(url.rstrip("/"))
+ base_url = f"{parsed.scheme}://{parsed.netloc}"
+ mount_path = parsed.path if parsed.path != "/" else ""
+ return base_url, mount_path
+
@property
def public_agent_card(self) -> AgentCard:
"""Get the public AgentCard for this agent.
@@ -119,24 +162,42 @@ def agent_skills(self, skills: list[AgentSkill]) -> None:
def to_starlette_app(self) -> Starlette:
"""Create a Starlette application for serving this agent via HTTP.
- This method creates a Starlette application that can be used to serve
- the agent via HTTP using the A2A protocol.
+ Automatically handles path-based mounting if a mount path was derived
+ from the http_url parameter.
Returns:
Starlette: A Starlette application configured to serve this agent.
"""
- return A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build()
+ a2a_app = A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build()
+
+ if self.mount_path:
+ # Create parent app and mount the A2A app at the specified path
+ parent_app = Starlette()
+ parent_app.mount(self.mount_path, a2a_app)
+ logger.info("Mounting A2A server at path: %s", self.mount_path)
+ return parent_app
+
+ return a2a_app
def to_fastapi_app(self) -> FastAPI:
"""Create a FastAPI application for serving this agent via HTTP.
- This method creates a FastAPI application that can be used to serve
- the agent via HTTP using the A2A protocol.
+ Automatically handles path-based mounting if a mount path was derived
+ from the http_url parameter.
Returns:
FastAPI: A FastAPI application configured to serve this agent.
"""
- return A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build()
+ a2a_app = A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build()
+
+ if self.mount_path:
+ # Create parent app and mount the A2A app at the specified path
+ parent_app = FastAPI()
+ parent_app.mount(self.mount_path, a2a_app)
+ logger.info("Mounting A2A server at path: %s", self.mount_path)
+ return parent_app
+
+ return a2a_app
def serve(
self,
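
A usage sketch of the two load-balancer scenarios (hostname and agent details hypothetical):

```python
from strands import Agent
from strands.multiagent.a2a import A2AServer

agent = Agent(name="Example Agent", description="An example agent")

# ALB preserves the /agent1 prefix: the A2A app is mounted at /agent1
server = A2AServer(agent, http_url="http://my-alb.amazonaws.com/agent1")

# ALB strips the prefix before forwarding: serve at the root path while the
# public AgentCard URL still advertises the external path
server = A2AServer(
    agent,
    http_url="http://my-alb.amazonaws.com/agent1",
    serve_at_root=True,
)
server.serve()
```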
diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py
index 18a6ac474..75058b251 100644
--- a/src/strands/session/repository_session_manager.py
+++ b/src/strands/session/repository_session_manager.py
@@ -133,9 +133,7 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
agent.state = AgentState(session_agent.state)
# Restore the conversation manager to its previous state, and get the optional prepend messages
- prepend_messages = agent.conversation_manager.restore_from_session(
- session_agent.conversation_manager_state
- )
+ prepend_messages = agent.conversation_manager.restore_from_session(session_agent.conversation_manager_state)
if prepend_messages is None:
prepend_messages = []
diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py
index 74f470741..fc76b5f1d 100644
--- a/tests/strands/multiagent/a2a/test_server.py
+++ b/tests/strands/multiagent/a2a/test_server.py
@@ -509,3 +509,346 @@ def test_serve_handles_general_exception(mock_run, mock_strands_agent, caplog):
assert "Strands A2A server encountered exception" in caplog.text
assert "Strands A2A server has shutdown" in caplog.text
+
+
+def test_initialization_with_http_url_no_path(mock_strands_agent):
+ """Test initialization with http_url containing no path."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(
+ mock_strands_agent, host="0.0.0.0", port=8080, http_url="http://my-alb.amazonaws.com", skills=[]
+ )
+
+ assert a2a_agent.host == "0.0.0.0"
+ assert a2a_agent.port == 8080
+ assert a2a_agent.http_url == "http://my-alb.amazonaws.com/"
+ assert a2a_agent.public_base_url == "http://my-alb.amazonaws.com"
+ assert a2a_agent.mount_path == ""
+
+
+def test_initialization_with_http_url_with_path(mock_strands_agent):
+ """Test initialization with http_url containing a path for mounting."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(
+ mock_strands_agent, host="0.0.0.0", port=8080, http_url="http://my-alb.amazonaws.com/agent1", skills=[]
+ )
+
+ assert a2a_agent.host == "0.0.0.0"
+ assert a2a_agent.port == 8080
+ assert a2a_agent.http_url == "http://my-alb.amazonaws.com/agent1/"
+ assert a2a_agent.public_base_url == "http://my-alb.amazonaws.com"
+ assert a2a_agent.mount_path == "/agent1"
+
+
+def test_initialization_with_https_url(mock_strands_agent):
+ """Test initialization with HTTPS URL."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="https://my-alb.amazonaws.com/secure-agent", skills=[])
+
+ assert a2a_agent.http_url == "https://my-alb.amazonaws.com/secure-agent/"
+ assert a2a_agent.public_base_url == "https://my-alb.amazonaws.com"
+ assert a2a_agent.mount_path == "/secure-agent"
+
+
+def test_initialization_with_http_url_with_port(mock_strands_agent):
+ """Test initialization with http_url containing explicit port."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://my-server.com:8080/api/agent", skills=[])
+
+ assert a2a_agent.http_url == "http://my-server.com:8080/api/agent/"
+ assert a2a_agent.public_base_url == "http://my-server.com:8080"
+ assert a2a_agent.mount_path == "/api/agent"
+
+
+def test_parse_public_url_method(mock_strands_agent):
+ """Test the _parse_public_url method directly."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+ a2a_agent = A2AServer(mock_strands_agent, skills=[])
+
+ # Test various URL formats
+ base_url, mount_path = a2a_agent._parse_public_url("http://example.com/path")
+ assert base_url == "http://example.com"
+ assert mount_path == "/path"
+
+ base_url, mount_path = a2a_agent._parse_public_url("https://example.com:443/deep/path")
+ assert base_url == "https://example.com:443"
+ assert mount_path == "/deep/path"
+
+ base_url, mount_path = a2a_agent._parse_public_url("http://example.com/")
+ assert base_url == "http://example.com"
+ assert mount_path == ""
+
+ base_url, mount_path = a2a_agent._parse_public_url("http://example.com")
+ assert base_url == "http://example.com"
+ assert mount_path == ""
+
+
+def test_public_agent_card_with_http_url(mock_strands_agent):
+ """Test that public_agent_card uses the http_url when provided."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="https://my-alb.amazonaws.com/agent1", skills=[])
+
+ card = a2a_agent.public_agent_card
+
+ assert isinstance(card, AgentCard)
+ assert card.url == "https://my-alb.amazonaws.com/agent1/"
+ assert card.name == "Test Agent"
+ assert card.description == "A test agent for unit testing"
+
+
+def test_to_starlette_app_with_mounting(mock_strands_agent):
+ """Test that to_starlette_app creates mounted app when mount_path exists."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[])
+
+ app = a2a_agent.to_starlette_app()
+
+ assert isinstance(app, Starlette)
+
+
+def test_to_starlette_app_without_mounting(mock_strands_agent):
+ """Test that to_starlette_app creates regular app when no mount_path."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com", skills=[])
+
+ app = a2a_agent.to_starlette_app()
+
+ assert isinstance(app, Starlette)
+
+
+def test_to_fastapi_app_with_mounting(mock_strands_agent):
+ """Test that to_fastapi_app creates mounted app when mount_path exists."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[])
+
+ app = a2a_agent.to_fastapi_app()
+
+ assert isinstance(app, FastAPI)
+
+
+def test_to_fastapi_app_without_mounting(mock_strands_agent):
+ """Test that to_fastapi_app creates regular app when no mount_path."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com", skills=[])
+
+ app = a2a_agent.to_fastapi_app()
+
+ assert isinstance(app, FastAPI)
+
+
+def test_backwards_compatibility_without_http_url(mock_strands_agent):
+ """Test that the old behavior is preserved when http_url is not provided."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, host="localhost", port=9000, skills=[])
+
+ # Should behave exactly like before
+ assert a2a_agent.host == "localhost"
+ assert a2a_agent.port == 9000
+ assert a2a_agent.http_url == "http://localhost:9000/"
+ assert a2a_agent.public_base_url == "http://localhost:9000"
+ assert a2a_agent.mount_path == ""
+
+ # Agent card should use the traditional URL
+ card = a2a_agent.public_agent_card
+ assert card.url == "http://localhost:9000/"
+
+
+def test_mount_path_logging(mock_strands_agent, caplog):
+ """Test that mounting logs the correct message."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/test-agent", skills=[])
+
+ # Test Starlette app mounting logs
+ caplog.clear()
+ a2a_agent.to_starlette_app()
+ assert "Mounting A2A server at path: /test-agent" in caplog.text
+
+ # Test FastAPI app mounting logs
+ caplog.clear()
+ a2a_agent.to_fastapi_app()
+ assert "Mounting A2A server at path: /test-agent" in caplog.text
+
+
+def test_http_url_trailing_slash_handling(mock_strands_agent):
+ """Test that trailing slashes in http_url are handled correctly."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # Test with trailing slash
+ a2a_agent1 = A2AServer(mock_strands_agent, http_url="http://example.com/agent1/", skills=[])
+
+ # Test without trailing slash
+ a2a_agent2 = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[])
+
+ # Both should result in the same normalized URL
+ assert a2a_agent1.http_url == "http://example.com/agent1/"
+ assert a2a_agent2.http_url == "http://example.com/agent1/"
+ assert a2a_agent1.mount_path == "/agent1"
+ assert a2a_agent2.mount_path == "/agent1"
+
+
+def test_serve_at_root_default_behavior(mock_strands_agent):
+ """Test default behavior extracts mount path from http_url."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ server = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[])
+
+ assert server.mount_path == "/agent1"
+ assert server.http_url == "http://my-alb.com/agent1/"
+
+
+def test_serve_at_root_overrides_mounting(mock_strands_agent):
+ """Test serve_at_root=True overrides automatic path mounting."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ server = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[])
+
+ assert server.mount_path == "" # Should be empty despite path in URL
+ assert server.http_url == "http://my-alb.com/agent1/" # Public URL unchanged
+
+
+def test_serve_at_root_with_no_path(mock_strands_agent):
+ """Test serve_at_root=True when no path in URL (redundant but valid)."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ server = A2AServer(mock_strands_agent, host="localhost", port=8080, serve_at_root=True, skills=[])
+
+ assert server.mount_path == ""
+ assert server.http_url == "http://localhost:8080/"
+
+
+def test_serve_at_root_complex_path(mock_strands_agent):
+ """Test serve_at_root=True with complex nested paths."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ server = A2AServer(
+ mock_strands_agent, http_url="http://api.example.com/v1/agents/my-agent", serve_at_root=True, skills=[]
+ )
+
+ assert server.mount_path == ""
+ assert server.http_url == "http://api.example.com/v1/agents/my-agent/"
+
+
+def test_serve_at_root_fastapi_mounting_behavior(mock_strands_agent):
+ """Test FastAPI mounting behavior with serve_at_root."""
+ from fastapi.testclient import TestClient
+
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # Normal mounting
+ server_mounted = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[])
+ app_mounted = server_mounted.to_fastapi_app()
+ client_mounted = TestClient(app_mounted)
+
+ # Should work at mounted path
+ response = client_mounted.get("/agent1/.well-known/agent.json")
+ assert response.status_code == 200
+
+ # Should not work at root
+ response = client_mounted.get("/.well-known/agent.json")
+ assert response.status_code == 404
+
+
+def test_serve_at_root_fastapi_root_behavior(mock_strands_agent):
+ """Test FastAPI serve_at_root behavior."""
+ from fastapi.testclient import TestClient
+
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # Serve at root
+ server_root = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[])
+ app_root = server_root.to_fastapi_app()
+ client_root = TestClient(app_root)
+
+ # Should work at root
+ response = client_root.get("/.well-known/agent.json")
+ assert response.status_code == 200
+
+ # Should not work at mounted path (since we're serving at root)
+ response = client_root.get("/agent1/.well-known/agent.json")
+ assert response.status_code == 404
+
+
+def test_serve_at_root_starlette_behavior(mock_strands_agent):
+ """Test Starlette serve_at_root behavior."""
+ from starlette.testclient import TestClient
+
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # Normal mounting
+ server_mounted = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[])
+ app_mounted = server_mounted.to_starlette_app()
+ client_mounted = TestClient(app_mounted)
+
+ # Should work at mounted path
+ response = client_mounted.get("/agent1/.well-known/agent.json")
+ assert response.status_code == 200
+
+ # Serve at root
+ server_root = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[])
+ app_root = server_root.to_starlette_app()
+ client_root = TestClient(app_root)
+
+ # Should work at root
+ response = client_root.get("/.well-known/agent.json")
+ assert response.status_code == 200
+
+
+def test_serve_at_root_alb_scenarios(mock_strands_agent):
+ """Test common ALB deployment scenarios."""
+ from fastapi.testclient import TestClient
+
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # ALB with path preservation
+ server_preserved = A2AServer(mock_strands_agent, http_url="http://my-alb.amazonaws.com/agent1", skills=[])
+ app_preserved = server_preserved.to_fastapi_app()
+ client_preserved = TestClient(app_preserved)
+
+ # Container receives /agent1/.well-known/agent.json
+ response = client_preserved.get("/agent1/.well-known/agent.json")
+ assert response.status_code == 200
+ agent_data = response.json()
+ assert agent_data["url"] == "http://my-alb.amazonaws.com/agent1/"
+
+ # ALB with path stripping
+ server_stripped = A2AServer(
+ mock_strands_agent, http_url="http://my-alb.amazonaws.com/agent1", serve_at_root=True, skills=[]
+ )
+ app_stripped = server_stripped.to_fastapi_app()
+ client_stripped = TestClient(app_stripped)
+
+ # Container receives /.well-known/agent.json (path stripped by ALB)
+ response = client_stripped.get("/.well-known/agent.json")
+ assert response.status_code == 200
+ agent_data = response.json()
+ assert agent_data["url"] == "http://my-alb.amazonaws.com/agent1/"
+
+
+def test_serve_at_root_edge_cases(mock_strands_agent):
+ """Test edge cases for serve_at_root parameter."""
+ mock_strands_agent.tool_registry.get_all_tools_config.return_value = {}
+
+ # Root path in URL
+ server1 = A2AServer(mock_strands_agent, http_url="http://example.com/", skills=[])
+ assert server1.mount_path == ""
+
+ # serve_at_root should be redundant but not cause issues
+ server2 = A2AServer(mock_strands_agent, http_url="http://example.com/", serve_at_root=True, skills=[])
+ assert server2.mount_path == ""
+
+ # Multiple nested paths
+ server3 = A2AServer(
+ mock_strands_agent, http_url="http://api.example.com/v1/agents/team1/agent1", serve_at_root=True, skills=[]
+ )
+ assert server3.mount_path == ""
+ assert server3.http_url == "http://api.example.com/v1/agents/team1/agent1/"
From b30e7e6e41e7a2dce70d74e8c1753503959f3619 Mon Sep 17 00:00:00 2001
From: poshinchen
Date: Wed, 23 Jul 2025 15:20:28 -0400
Subject: [PATCH 010/104] fix: include agent trace into tool for agent as tools
(#526)
---
src/strands/telemetry/tracer.py | 2 +-
src/strands/tools/executor.py | 37 ++++++++++++++++-----------------
2 files changed, 19 insertions(+), 20 deletions(-)
diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py
index eebffef29..802865189 100644
--- a/src/strands/telemetry/tracer.py
+++ b/src/strands/telemetry/tracer.py
@@ -273,7 +273,7 @@ def end_model_invoke_span(
self._end_span(span, attributes, error)
- def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Optional[Span]:
+ def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Span:
"""Start a new span for a tool call.
Args:
diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py
index 1214fa608..d90f9a5aa 100644
--- a/src/strands/tools/executor.py
+++ b/src/strands/tools/executor.py
@@ -5,7 +5,7 @@
import time
from typing import Any, Optional, cast
-from opentelemetry import trace
+from opentelemetry import trace as trace_api
from ..telemetry.metrics import EventLoopMetrics, Trace
from ..telemetry.tracer import get_tracer
@@ -23,7 +23,7 @@ async def run_tools(
invalid_tool_use_ids: list[str],
tool_results: list[ToolResult],
cycle_trace: Trace,
- parent_span: Optional[trace.Span] = None,
+ parent_span: Optional[trace_api.Span] = None,
) -> ToolGenerator:
"""Execute tools concurrently.
@@ -53,24 +53,23 @@ async def work(
tool_name = tool_use["name"]
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
tool_start_time = time.time()
+ with trace_api.use_span(tool_call_span):
+ try:
+ async for event in handler(tool_use):
+ worker_queue.put_nowait((worker_id, event))
+ await worker_event.wait()
+ worker_event.clear()
+
+ result = cast(ToolResult, event)
+ finally:
+ worker_queue.put_nowait((worker_id, stop_event))
+
+ tool_success = result.get("status") == "success"
+ tool_duration = time.time() - tool_start_time
+ message = Message(role="user", content=[{"toolResult": result}])
+ event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
+ cycle_trace.add_child(tool_trace)
- try:
- async for event in handler(tool_use):
- worker_queue.put_nowait((worker_id, event))
- await worker_event.wait()
- worker_event.clear()
-
- result = cast(ToolResult, event)
- finally:
- worker_queue.put_nowait((worker_id, stop_event))
-
- tool_success = result.get("status") == "success"
- tool_duration = time.time() - tool_start_time
- message = Message(role="user", content=[{"toolResult": result}])
- event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
- cycle_trace.add_child(tool_trace)
-
- if tool_call_span:
tracer.end_tool_call_span(tool_call_span, result)
return result
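
For context, a minimal sketch of the OpenTelemetry pattern the change relies on: `trace.use_span` makes the tool-call span current, so spans started inside the handler (for example, by a nested agent used as a tool) attach to it as children instead of starting a detached trace:

```python
from opentelemetry import trace

tracer = trace.get_tracer(__name__)

tool_span = tracer.start_span("Tool: example")  # hypothetical span name
with trace.use_span(tool_span, end_on_exit=False):
    # Spans created while this context is active become children of tool_span
    with tracer.start_as_current_span("nested-agent-cycle"):
        pass
tool_span.end()
```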
From 8c5562575f8c6c26c2b2a18591d1d5926a96514a Mon Sep 17 00:00:00 2001
From: Davide Gallitelli
Date: Mon, 28 Jul 2025 13:34:04 +0200
Subject: [PATCH 011/104] Support for Amazon SageMaker AI endpoints as Model
Provider (#176)
---
pyproject.toml | 18 +-
src/strands/models/sagemaker.py | 600 +++++++++++++++++++++
tests/strands/models/test_sagemaker.py | 574 ++++++++++++++++++++
tests_integ/models/test_model_sagemaker.py | 76 +++
4 files changed, 1262 insertions(+), 6 deletions(-)
create mode 100644 src/strands/models/sagemaker.py
create mode 100644 tests/strands/models/test_sagemaker.py
create mode 100644 tests_integ/models/test_model_sagemaker.py
diff --git a/pyproject.toml b/pyproject.toml
index 765e815ef..745c80e0c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -89,8 +89,14 @@ writer = [
"writer-sdk>=2.2.0,<3.0.0"
]
+sagemaker = [
+ "boto3>=1.26.0,<2.0.0",
+ "botocore>=1.29.0,<2.0.0",
+ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0"
+]
+
a2a = [
- "a2a-sdk[sql]>=0.2.16,<1.0.0",
+ "a2a-sdk[sql]>=0.2.11,<1.0.0",
"uvicorn>=0.34.2,<1.0.0",
"httpx>=0.28.1,<1.0.0",
"fastapi>=0.115.12,<1.0.0",
@@ -136,7 +142,7 @@ all = [
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
# a2a
- "a2a-sdk[sql]>=0.2.16,<1.0.0",
+ "a2a-sdk[sql]>=0.2.11,<1.0.0",
"uvicorn>=0.34.2,<1.0.0",
"httpx>=0.28.1,<1.0.0",
"fastapi>=0.115.12,<1.0.0",
@@ -148,7 +154,7 @@ all = [
source = "vcs"
[tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"]
dependencies = [
"mypy>=1.15.0,<2.0.0",
"ruff>=0.11.6,<0.12.0",
@@ -171,7 +177,7 @@ lint-fix = [
]
[tool.hatch.envs.hatch-test]
-features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"]
extra-dependencies = [
"moto>=5.1.0,<6.0.0",
"pytest>=8.0.0,<9.0.0",
@@ -187,7 +193,7 @@ extra-args = [
[tool.hatch.envs.dev]
dev-mode = true
-features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a"]
+features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a", "sagemaker"]
[[tool.hatch.envs.hatch-test.matrix]]
python = ["3.13", "3.12", "3.11", "3.10"]
@@ -315,4 +321,4 @@ style = [
["instruction", ""],
["text", ""],
["disabled", "fg:#858585 italic"]
-]
+]
\ No newline at end of file
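
For orientation before the new module, a hedged usage sketch of the provider it adds (endpoint name and region hypothetical), mirroring the constructor introduced below:

```python
from strands import Agent
from strands.models.sagemaker import SageMakerAIModel

model = SageMakerAIModel(
    endpoint_config={
        "endpoint_name": "my-openai-compatible-endpoint",  # hypothetical endpoint
        "region_name": "us-east-1",
    },
    payload_config={"max_tokens": 1024, "stream": True},
)
agent = Agent(model=model)
response = agent("Hello!")
```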
diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py
new file mode 100644
index 000000000..bb2db45a2
--- /dev/null
+++ b/src/strands/models/sagemaker.py
@@ -0,0 +1,600 @@
+"""Amazon SageMaker model provider."""
+
+import json
+import logging
+import os
+from dataclasses import dataclass
+from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast
+
+import boto3
+from botocore.config import Config as BotocoreConfig
+from mypy_boto3_sagemaker_runtime import SageMakerRuntimeClient
+from pydantic import BaseModel
+from typing_extensions import Unpack, override
+
+from ..types.content import ContentBlock, Messages
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolResult, ToolSpec
+from .openai import OpenAIModel
+
+T = TypeVar("T", bound=BaseModel)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class UsageMetadata:
+ """Usage metadata for the model.
+
+ Attributes:
+ total_tokens: Total number of tokens used in the request
+ completion_tokens: Number of tokens used in the completion
+ prompt_tokens: Number of tokens used in the prompt
+ prompt_tokens_details: Additional information about the prompt tokens (optional)
+ """
+
+ total_tokens: int
+ completion_tokens: int
+ prompt_tokens: int
+ prompt_tokens_details: Optional[int] = 0
+
+
+@dataclass
+class FunctionCall:
+ """Function call for the model.
+
+ Attributes:
+ name: Name of the function to call
+ arguments: Arguments to pass to the function
+ """
+
+ name: Union[str, dict[Any, Any]]
+ arguments: Union[str, dict[Any, Any]]
+
+ def __init__(self, **kwargs: dict[str, str]):
+ """Initialize function call.
+
+ Args:
+ **kwargs: Keyword arguments for the function call.
+ """
+ self.name = kwargs.get("name", "")
+ self.arguments = kwargs.get("arguments", "")
+
+
+@dataclass
+class ToolCall:
+ """Tool call for the model object.
+
+ Attributes:
+ id: Tool call ID
+ type: Tool call type
+ function: Tool call function
+ """
+
+ id: str
+ type: Literal["function"]
+ function: FunctionCall
+
+ def __init__(self, **kwargs: dict):
+ """Initialize tool call object.
+
+ Args:
+ **kwargs: Keyword arguments for the tool call.
+ """
+ self.id = str(kwargs.get("id", ""))
+ self.type = "function"
+ self.function = FunctionCall(**kwargs.get("function", {"name": "", "arguments": ""}))
+
+
+class SageMakerAIModel(OpenAIModel):
+ """Amazon SageMaker model provider implementation."""
+
+ client: SageMakerRuntimeClient # type: ignore[assignment]
+
+ class SageMakerAIPayloadSchema(TypedDict, total=False):
+ """Payload schema for the Amazon SageMaker AI model.
+
+ Attributes:
+ max_tokens: Maximum number of tokens to generate in the completion
+ stream: Whether to stream the response
+ temperature: Sampling temperature to use for the model (optional)
+ top_p: Nucleus sampling parameter (optional)
+ top_k: Top-k sampling parameter (optional)
+ stop: List of stop sequences to use for the model (optional)
+ tool_results_as_user_messages: Convert tool result to user messages (optional)
+ additional_args: Additional request parameters, as supported by https://bit.ly/djl-lmi-request-schema
+ """
+
+ max_tokens: int
+ stream: bool
+ temperature: Optional[float]
+ top_p: Optional[float]
+ top_k: Optional[int]
+ stop: Optional[list[str]]
+ tool_results_as_user_messages: Optional[bool]
+ additional_args: Optional[dict[str, Any]]
+
+ class SageMakerAIEndpointConfig(TypedDict, total=False):
+ """Configuration options for SageMaker models.
+
+ Attributes:
+ endpoint_name: The name of the SageMaker endpoint to invoke
+ region_name: The AWS region where the endpoint is deployed
+ inference_component_name: The name of the inference component to use
+ target_model: The target model to invoke on a multi-model endpoint (optional)
+ target_variant: The production variant to invoke (optional)
+ additional_args: Other request parameters, as supported by https://bit.ly/sagemaker-invoke-endpoint-params
+ """
+
+ endpoint_name: str
+ region_name: str
+ inference_component_name: Union[str, None]
+ target_model: Union[Optional[str], None]
+ target_variant: Union[Optional[str], None]
+ additional_args: Optional[dict[str, Any]]
+
+ def __init__(
+ self,
+ endpoint_config: SageMakerAIEndpointConfig,
+ payload_config: SageMakerAIPayloadSchema,
+ boto_session: Optional[boto3.Session] = None,
+ boto_client_config: Optional[BotocoreConfig] = None,
+ ):
+ """Initialize provider instance.
+
+ Args:
+ endpoint_config: Endpoint configuration for SageMaker.
+ payload_config: Payload configuration for the model.
+ boto_session: Boto Session to use when calling the SageMaker Runtime.
+ boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client.
+ """
+ payload_config.setdefault("stream", True)
+ payload_config.setdefault("tool_results_as_user_messages", False)
+ self.endpoint_config = dict(endpoint_config)
+ self.payload_config = dict(payload_config)
+ logger.debug(
+ "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config
+ )
+
+ region = self.endpoint_config.get("region_name") or os.getenv("AWS_REGION") or "us-west-2"
+ session = boto_session or boto3.Session(region_name=str(region))
+
+ # Add strands-agents to the request user agent
+ if boto_client_config:
+ existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
+
+ # Append 'strands-agents' to existing user_agent_extra or set it if not present
+ new_user_agent = f"{existing_user_agent} strands-agents" if existing_user_agent else "strands-agents"
+
+ client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
+ else:
+ client_config = BotocoreConfig(user_agent_extra="strands-agents")
+
+ self.client = session.client(
+ service_name="sagemaker-runtime",
+ config=client_config,
+ )
+
+ @override
+ def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) -> None: # type: ignore[override]
+ """Update the Amazon SageMaker model configuration with the provided arguments.
+
+ Args:
+ **endpoint_config: Configuration overrides.
+ """
+ self.endpoint_config.update(endpoint_config)
+
+ @override
+ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: ignore[override]
+ """Get the Amazon SageMaker model configuration.
+
+ Returns:
+ The Amazon SageMaker model configuration.
+ """
+ return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config)
+
+ @override
+ def format_request(
+ self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+ ) -> dict[str, Any]:
+ """Format an Amazon SageMaker chat streaming request.
+
+ Args:
+ messages: List of message objects to be processed by the model.
+ tool_specs: List of tool specifications to make available to the model.
+ system_prompt: System prompt to provide context to the model.
+
+ Returns:
+ An Amazon SageMaker chat streaming request.
+ """
+ formatted_messages = self.format_request_messages(messages, system_prompt)
+
+ payload = {
+ "messages": formatted_messages,
+ "tools": [
+ {
+ "type": "function",
+ "function": {
+ "name": tool_spec["name"],
+ "description": tool_spec["description"],
+ "parameters": tool_spec["inputSchema"]["json"],
+ },
+ }
+ for tool_spec in tool_specs or []
+ ],
+ # Add payload configuration parameters
+ **{
+ k: v
+ for k, v in self.payload_config.items()
+ if k not in ["additional_args", "tool_results_as_user_messages"]
+ },
+ }
+
+ # Remove tools and tool_choice if tools = []
+ if not payload["tools"]:
+ payload.pop("tools")
+ payload.pop("tool_choice", None)
+ else:
+ # Ensure the model can use tools when available
+ payload["tool_choice"] = "auto"
+
+ for message in payload["messages"]: # type: ignore
+ # Assistant message must have either content or tool_calls, but not both
+ if message.get("role", "") == "assistant" and message.get("tool_calls", []) != []:
+ message.pop("content", None)
+ if message.get("role") == "tool" and self.payload_config.get("tool_results_as_user_messages", False):
+ # Convert tool message to user message
+ tool_call_id = message.get("tool_call_id", "ABCDEF")
+ content = message.get("content", "")
+ message = {"role": "user", "content": f"Tool call ID '{tool_call_id}' returned: {content}"}
+ # Cannot have both reasoning_text and text - if "text", content becomes an array of content["text"]
+ for c in message.get("content", []):
+ if "text" in c:
+ message["content"] = [c]
+ break
+ # Cast message content to string for TGI compatibility
+ # message["content"] = str(message.get("content", ""))
+
+ logger.info("payload=<%s>", json.dumps(payload, indent=2))
+ # Format the request according to the SageMaker Runtime API requirements
+ request = {
+ "EndpointName": self.endpoint_config["endpoint_name"],
+ "Body": json.dumps(payload),
+ "ContentType": "application/json",
+ "Accept": "application/json",
+ }
+
+ # Add optional SageMaker parameters if provided
+ if self.endpoint_config.get("inference_component_name"):
+ request["InferenceComponentName"] = self.endpoint_config["inference_component_name"]
+ if self.endpoint_config.get("target_model"):
+ request["TargetModel"] = self.endpoint_config["target_model"]
+ if self.endpoint_config.get("target_variant"):
+ request["TargetVariant"] = self.endpoint_config["target_variant"]
+
+ # Add additional args if provided
+ if self.endpoint_config.get("additional_args"):
+ request.update(self.endpoint_config["additional_args"].__dict__)
+
+ print(json.dumps(request["Body"], indent=2))
+
+ return request
+
+ @override
+ async def stream(
+ self,
+ messages: Messages,
+ tool_specs: Optional[list[ToolSpec]] = None,
+ system_prompt: Optional[str] = None,
+ **kwargs: Any,
+ ) -> AsyncGenerator[StreamEvent, None]:
+ """Stream conversation with the SageMaker model.
+
+ Args:
+ messages: List of message objects to be processed by the model.
+ tool_specs: List of tool specifications to make available to the model.
+ system_prompt: System prompt to provide context to the model.
+ **kwargs: Additional keyword arguments for future extensibility.
+
+ Yields:
+ Formatted message chunks from the model.
+ """
+ logger.debug("formatting request")
+ request = self.format_request(messages, tool_specs, system_prompt)
+ logger.debug("formatted request=<%s>", request)
+
+ logger.debug("invoking model")
+ try:
+ if self.payload_config.get("stream", True):
+ response = self.client.invoke_endpoint_with_response_stream(**request)
+
+ # Message start
+ yield self.format_chunk({"chunk_type": "message_start"})
+
+ # Parse the content
+ finish_reason = ""
+ partial_content = ""
+ tool_calls: dict[int, list[Any]] = {}
+ has_text_content = False
+ text_content_started = False
+ reasoning_content_started = False
+
+ for event in response["Body"]:
+ chunk = event["PayloadPart"]["Bytes"].decode("utf-8")
+ partial_content += chunk[6:] if chunk.startswith("data: ") else chunk # TGI fix
+ logger.info("chunk=<%s>", partial_content)
+ try:
+ content = json.loads(partial_content)
+ partial_content = ""
+ choice = content["choices"][0]
+ logger.info("choice=<%s>", json.dumps(choice, indent=2))
+
+ # Handle text content
+ if choice["delta"].get("content", None):
+ if not text_content_started:
+ yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+ text_content_started = True
+ has_text_content = True
+ yield self.format_chunk(
+ {
+ "chunk_type": "content_delta",
+ "data_type": "text",
+ "data": choice["delta"]["content"],
+ }
+ )
+
+ # Handle reasoning content
+ if choice["delta"].get("reasoning_content", None):
+ if not reasoning_content_started:
+ yield self.format_chunk(
+ {"chunk_type": "content_start", "data_type": "reasoning_content"}
+ )
+ reasoning_content_started = True
+ yield self.format_chunk(
+ {
+ "chunk_type": "content_delta",
+ "data_type": "reasoning_content",
+ "data": choice["delta"]["reasoning_content"],
+ }
+ )
+
+ # Handle tool calls
+ generated_tool_calls = choice["delta"].get("tool_calls", [])
+ if not isinstance(generated_tool_calls, list):
+ generated_tool_calls = [generated_tool_calls]
+ for tool_call in generated_tool_calls:
+ tool_calls.setdefault(tool_call["index"], []).append(tool_call)
+
+ if choice["finish_reason"] is not None:
+ finish_reason = choice["finish_reason"]
+ break
+
+ if choice.get("usage", None):
+ yield self.format_chunk(
+ {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])}
+ )
+
+ except json.JSONDecodeError:
+ # Continue accumulating content until we have valid JSON
+ continue
+
+ # Close reasoning content if it was started
+ if reasoning_content_started:
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})
+
+ # Close text content if it was started
+ if text_content_started:
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+ # Handle tool calling
+                logger.debug("tool_calls=<%s>", json.dumps(tool_calls, indent=2))
+ for tool_deltas in tool_calls.values():
+ if not tool_deltas[0]["function"].get("name", None):
+ raise Exception("The model did not provide a tool name.")
+ yield self.format_chunk(
+ {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_deltas[0])}
+ )
+ for tool_delta in tool_deltas:
+ yield self.format_chunk(
+ {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_delta)}
+ )
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+ # If no content was generated at all, ensure we have empty text content
+ if not has_text_content and not tool_calls:
+ yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+ # Message close
+ yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
+
+ else:
+ # Not all SageMaker AI models support streaming!
+ response = self.client.invoke_endpoint(**request) # type: ignore[assignment]
+ final_response_json = json.loads(response["Body"].read().decode("utf-8")) # type: ignore[attr-defined]
+                logger.debug("response=<%s>", json.dumps(final_response_json, indent=2))
+
+ # Obtain the key elements from the response
+ message = final_response_json["choices"][0]["message"]
+ message_stop_reason = final_response_json["choices"][0]["finish_reason"]
+
+ # Message start
+ yield self.format_chunk({"chunk_type": "message_start"})
+
+ # Handle text
+ if message.get("content", ""):
+ yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+ yield self.format_chunk(
+ {"chunk_type": "content_delta", "data_type": "text", "data": message["content"]}
+ )
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+ # Handle reasoning content
+ if message.get("reasoning_content", None):
+ yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"})
+ yield self.format_chunk(
+ {
+ "chunk_type": "content_delta",
+ "data_type": "reasoning_content",
+ "data": message["reasoning_content"],
+ }
+ )
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})
+
+ # Handle the tool calling, if any
+ if message.get("tool_calls", None) or message_stop_reason == "tool_calls":
+ if not isinstance(message["tool_calls"], list):
+ message["tool_calls"] = [message["tool_calls"]]
+ for tool_call in message["tool_calls"]:
+ # if arguments of tool_call is not str, cast it
+ if not isinstance(tool_call["function"]["arguments"], str):
+ tool_call["function"]["arguments"] = json.dumps(tool_call["function"]["arguments"])
+ yield self.format_chunk(
+ {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_call)}
+ )
+ yield self.format_chunk(
+ {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_call)}
+ )
+ yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+ message_stop_reason = "tool_calls"
+
+ # Message close
+ yield self.format_chunk({"chunk_type": "message_stop", "data": message_stop_reason})
+ # Handle usage metadata
+ if final_response_json.get("usage", None):
+ yield self.format_chunk(
+ {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage", None))}
+ )
+ except (
+ self.client.exceptions.InternalFailure,
+ self.client.exceptions.ServiceUnavailable,
+ self.client.exceptions.ValidationError,
+ self.client.exceptions.ModelError,
+ self.client.exceptions.InternalDependencyException,
+ self.client.exceptions.ModelNotReadyException,
+ ) as e:
+ logger.error("SageMaker error: %s", str(e))
+ raise e
+
+ logger.debug("finished streaming response from model")
+
+ @override
+ @classmethod
+ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
+ """Format a SageMaker compatible tool message.
+
+ Args:
+ tool_result: Tool result collected from a tool execution.
+
+ Returns:
+ SageMaker compatible tool message with content as a string.
+ """
+ # Convert content blocks to a simple string for SageMaker compatibility
+ content_parts = []
+ for content in tool_result["content"]:
+ if "json" in content:
+ content_parts.append(json.dumps(content["json"]))
+ elif "text" in content:
+ content_parts.append(content["text"])
+ else:
+ # Handle other content types by converting to string
+ content_parts.append(str(content))
+
+ content_string = " ".join(content_parts)
+
+ return {
+ "role": "tool",
+ "tool_call_id": tool_result["toolUseId"],
+ "content": content_string, # String instead of list
+ }
+
+ @override
+ @classmethod
+ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
+ """Format a content block.
+
+ Args:
+ content: Message content.
+
+ Returns:
+ Formatted content block.
+
+ Raises:
+ TypeError: If the content block type cannot be converted to a SageMaker-compatible format.
+ """
+ if "reasoningContent" in content and content["reasoningContent"]:
+ return {
+ "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""),
+ "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""),
+ "type": "thinking",
+ }
+ elif not content.get("reasoningContent", None):
+ content.pop("reasoningContent", None)
+
+ if "video" in content:
+ return {
+ "type": "video_url",
+ "video_url": {
+ "detail": "auto",
+ "url": content["video"]["source"]["bytes"],
+ },
+ }
+
+ return super().format_request_message_content(content)
+
+ @override
+ async def structured_output(
+ self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+ ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+ """Get structured output from the model.
+
+ Args:
+ output_model: The output model to use for the agent.
+ prompt: The prompt messages to use for the agent.
+ system_prompt: System prompt to provide context to the model.
+ **kwargs: Additional keyword arguments for future extensibility.
+
+ Yields:
+ Model events with the last being the structured output.
+ """
+ # Format the request for structured output
+ request = self.format_request(prompt, system_prompt=system_prompt)
+
+ # Parse the payload to add response format
+ payload = json.loads(request["Body"])
+ payload["response_format"] = {
+ "type": "json_schema",
+ "json_schema": {"name": output_model.__name__, "schema": output_model.model_json_schema(), "strict": True},
+ }
+ request["Body"] = json.dumps(payload)
+
+ try:
+ # Use non-streaming mode for structured output
+ response = self.client.invoke_endpoint(**request)
+ final_response_json = json.loads(response["Body"].read().decode("utf-8"))
+
+ # Extract the structured content
+ message = final_response_json["choices"][0]["message"]
+
+ if message.get("content"):
+ try:
+ # Parse the JSON content and create the output model instance
+ content_data = json.loads(message["content"])
+ parsed_output = output_model(**content_data)
+ yield {"output": parsed_output}
+ except (json.JSONDecodeError, TypeError, ValueError) as e:
+ raise ValueError(f"Failed to parse structured output: {e}") from e
+ else:
+ raise ValueError("No content found in SageMaker response")
+
+ except (
+ self.client.exceptions.InternalFailure,
+ self.client.exceptions.ServiceUnavailable,
+ self.client.exceptions.ValidationError,
+ self.client.exceptions.ModelError,
+ self.client.exceptions.InternalDependencyException,
+ self.client.exceptions.ModelNotReadyException,
+ ) as e:
+ logger.error("SageMaker structured output error: %s", str(e))
+ raise ValueError(f"SageMaker structured output error: {str(e)}") from e
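For reviewers, a minimal usage sketch of the provider added above. The endpoint name is a placeholder; it assumes a deployed SageMaker endpoint that speaks the OpenAI chat-completions schema (for example a TGI or vLLM container):

```python
import asyncio

from strands.models.sagemaker import SageMakerAIModel

# "my-llm-endpoint" is a hypothetical endpoint name; substitute a real deployment.
model = SageMakerAIModel(
    endpoint_config={"endpoint_name": "my-llm-endpoint", "region_name": "us-east-1"},
    payload_config={"max_tokens": 1024, "temperature": 0.7, "stream": True},
)

async def main() -> None:
    messages = [{"role": "user", "content": [{"text": "What is the capital of France?"}]}]
    # stream() yields formatted chunks: messageStart, contentBlock*, messageStop, metadata.
    async for event in model.stream(messages):
        print(event)

asyncio.run(main())
```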
diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py
new file mode 100644
index 000000000..ba395b2d6
--- /dev/null
+++ b/tests/strands/models/test_sagemaker.py
@@ -0,0 +1,574 @@
+"""Tests for the Amazon SageMaker model provider."""
+
+import json
+import unittest.mock
+from typing import Any, Dict, List
+
+import boto3
+import pytest
+from botocore.config import Config as BotocoreConfig
+
+from strands.models.sagemaker import (
+ FunctionCall,
+ SageMakerAIModel,
+ ToolCall,
+ UsageMetadata,
+)
+from strands.types.content import Messages
+from strands.types.tools import ToolSpec
+
+
+@pytest.fixture
+def boto_session():
+ """Mock boto3 session."""
+ with unittest.mock.patch.object(boto3, "Session") as mock_session:
+ yield mock_session.return_value
+
+
+@pytest.fixture
+def sagemaker_client(boto_session):
+ """Mock SageMaker runtime client."""
+ return boto_session.client.return_value
+
+
+@pytest.fixture
+def endpoint_config() -> Dict[str, Any]:
+ """Default endpoint configuration for tests."""
+ return {
+ "endpoint_name": "test-endpoint",
+ "inference_component_name": "test-component",
+ "region_name": "us-east-1",
+ }
+
+
+@pytest.fixture
+def payload_config() -> Dict[str, Any]:
+ """Default payload configuration for tests."""
+ return {
+ "max_tokens": 1024,
+ "temperature": 0.7,
+ "stream": True,
+ }
+
+
+@pytest.fixture
+def model(boto_session, endpoint_config, payload_config):
+ """SageMaker model instance with mocked boto session."""
+ return SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session)
+
+
+@pytest.fixture
+def messages() -> Messages:
+ """Sample messages for testing."""
+ return [{"role": "user", "content": [{"text": "What is the capital of France?"}]}]
+
+
+@pytest.fixture
+def tool_specs() -> List[ToolSpec]:
+ """Sample tool specifications for testing."""
+ return [
+ {
+ "name": "get_weather",
+ "description": "Get the weather for a location",
+ "inputSchema": {
+ "json": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ "required": ["location"],
+ }
+ },
+ }
+ ]
+
+
+@pytest.fixture
+def system_prompt() -> str:
+ """Sample system prompt for testing."""
+ return "You are a helpful assistant."
+
+
+class TestSageMakerAIModel:
+ """Test suite for SageMakerAIModel."""
+
+ def test_init_default(self, boto_session):
+ """Test initialization with default parameters."""
+ endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"}
+ payload_config = {"max_tokens": 1024}
+ model = SageMakerAIModel(
+ endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session
+ )
+
+ assert model.endpoint_config["endpoint_name"] == "test-endpoint"
+ assert model.payload_config.get("stream", True) is True
+
+ boto_session.client.assert_called_once_with(
+ service_name="sagemaker-runtime",
+ config=unittest.mock.ANY,
+ )
+
+ def test_init_with_all_params(self, boto_session):
+ """Test initialization with all parameters."""
+ endpoint_config = {
+ "endpoint_name": "test-endpoint",
+ "inference_component_name": "test-component",
+ "region_name": "us-west-2",
+ }
+ payload_config = {
+ "stream": False,
+ "max_tokens": 1024,
+ "temperature": 0.7,
+ }
+ client_config = BotocoreConfig(user_agent_extra="test-agent")
+
+ model = SageMakerAIModel(
+ endpoint_config=endpoint_config,
+ payload_config=payload_config,
+ boto_session=boto_session,
+ boto_client_config=client_config,
+ )
+
+ assert model.endpoint_config["endpoint_name"] == "test-endpoint"
+ assert model.endpoint_config["inference_component_name"] == "test-component"
+ assert model.payload_config["stream"] is False
+ assert model.payload_config["max_tokens"] == 1024
+ assert model.payload_config["temperature"] == 0.7
+
+ boto_session.client.assert_called_once_with(
+ service_name="sagemaker-runtime",
+ config=unittest.mock.ANY,
+ )
+
+ def test_init_with_client_config(self, boto_session):
+ """Test initialization with client configuration."""
+ endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"}
+ payload_config = {"max_tokens": 1024}
+ client_config = BotocoreConfig(user_agent_extra="test-agent")
+
+ SageMakerAIModel(
+ endpoint_config=endpoint_config,
+ payload_config=payload_config,
+ boto_session=boto_session,
+ boto_client_config=client_config,
+ )
+
+ # Verify client was created with a config that includes our user agent
+ boto_session.client.assert_called_once_with(
+ service_name="sagemaker-runtime",
+ config=unittest.mock.ANY,
+ )
+
+ # Get the actual config passed to client
+ actual_config = boto_session.client.call_args[1]["config"]
+ assert "strands-agents" in actual_config.user_agent_extra
+ assert "test-agent" in actual_config.user_agent_extra
+
+ def test_update_config(self, model):
+ """Test updating model configuration."""
+ new_config = {"target_model": "new-model", "target_variant": "new-variant"}
+ model.update_config(**new_config)
+
+ assert model.endpoint_config["target_model"] == "new-model"
+ assert model.endpoint_config["target_variant"] == "new-variant"
+ # Original values should be preserved
+ assert model.endpoint_config["endpoint_name"] == "test-endpoint"
+ assert model.endpoint_config["inference_component_name"] == "test-component"
+
+ def test_get_config(self, model, endpoint_config):
+ """Test getting model configuration."""
+ config = model.get_config()
+ assert config == model.endpoint_config
+ assert isinstance(config, dict)
+
+ # def test_format_request_messages_with_system_prompt(self, model):
+ # """Test formatting request messages with system prompt."""
+ # messages = [{"role": "user", "content": "Hello"}]
+ # system_prompt = "You are a helpful assistant."
+
+ # formatted_messages = model.format_request_messages(messages, system_prompt)
+
+ # assert len(formatted_messages) == 2
+ # assert formatted_messages[0]["role"] == "system"
+ # assert formatted_messages[0]["content"] == system_prompt
+ # assert formatted_messages[1]["role"] == "user"
+ # assert formatted_messages[1]["content"] == "Hello"
+
+ # def test_format_request_messages_with_tool_calls(self, model):
+ # """Test formatting request messages with tool calls."""
+ # messages = [
+ # {"role": "user", "content": "Hello"},
+ # {
+ # "role": "assistant",
+ # "content": None,
+ # "tool_calls": [{"id": "123", "type": "function", "function": {"name": "test", "arguments": "{}"}}],
+ # },
+ # ]
+
+ # formatted_messages = model.format_request_messages(messages, None)
+
+ # assert len(formatted_messages) == 2
+ # assert formatted_messages[0]["role"] == "user"
+ # assert formatted_messages[1]["role"] == "assistant"
+ # assert "content" not in formatted_messages[1]
+ # assert "tool_calls" in formatted_messages[1]
+
+ # def test_format_request(self, model, messages, tool_specs, system_prompt):
+ # """Test formatting a request with all parameters."""
+ # request = model.format_request(messages, tool_specs, system_prompt)
+
+ # assert request["EndpointName"] == "test-endpoint"
+ # assert request["InferenceComponentName"] == "test-component"
+ # assert request["ContentType"] == "application/json"
+ # assert request["Accept"] == "application/json"
+
+ # payload = json.loads(request["Body"])
+ # assert "messages" in payload
+ # assert len(payload["messages"]) > 0
+ # assert "tools" in payload
+ # assert len(payload["tools"]) == 1
+ # assert payload["tools"][0]["type"] == "function"
+ # assert payload["tools"][0]["function"]["name"] == "get_weather"
+ # assert payload["max_tokens"] == 1024
+ # assert payload["temperature"] == 0.7
+ # assert payload["stream"] is True
+
+ # def test_format_request_without_tools(self, model, messages, system_prompt):
+ # """Test formatting a request without tools."""
+ # request = model.format_request(messages, None, system_prompt)
+
+ # payload = json.loads(request["Body"])
+ # assert "tools" in payload
+ # assert payload["tools"] == []
+
+ @pytest.mark.asyncio
+ async def test_stream_with_streaming_enabled(self, sagemaker_client, model, messages):
+ """Test streaming response with streaming enabled."""
+ # Mock the response from SageMaker
+ mock_response = {
+ "Body": [
+ {
+ "PayloadPart": {
+ "Bytes": json.dumps(
+ {
+ "choices": [
+ {
+ "delta": {"content": "Paris is the capital of France."},
+ "finish_reason": None,
+ }
+ ]
+ }
+ ).encode("utf-8")
+ }
+ },
+ {
+ "PayloadPart": {
+ "Bytes": json.dumps(
+ {
+ "choices": [
+ {
+ "delta": {"content": " It is known for the Eiffel Tower."},
+ "finish_reason": "stop",
+ }
+ ]
+ }
+ ).encode("utf-8")
+ }
+ },
+ ]
+ }
+ sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response
+
+ response = [chunk async for chunk in model.stream(messages)]
+
+ assert len(response) >= 5
+ assert response[0] == {"messageStart": {"role": "assistant"}}
+
+ # Find content events
+ content_start = next((e for e in response if "contentBlockStart" in e), None)
+ content_delta = next((e for e in response if "contentBlockDelta" in e), None)
+ content_stop = next((e for e in response if "contentBlockStop" in e), None)
+ message_stop = next((e for e in response if "messageStop" in e), None)
+
+ assert content_start is not None
+ assert content_delta is not None
+ assert content_stop is not None
+ assert message_stop is not None
+ assert message_stop["messageStop"]["stopReason"] == "end_turn"
+
+ sagemaker_client.invoke_endpoint_with_response_stream.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_stream_with_tool_calls(self, sagemaker_client, model, messages):
+ """Test streaming response with tool calls."""
+ # Mock the response from SageMaker with tool calls
+ mock_response = {
+ "Body": [
+ {
+ "PayloadPart": {
+ "Bytes": json.dumps(
+ {
+ "choices": [
+ {
+ "delta": {
+ "content": None,
+ "tool_calls": [
+ {
+ "index": 0,
+ "id": "tool123",
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "arguments": '{"location": "Paris"}',
+ },
+ }
+ ],
+ },
+ "finish_reason": "tool_calls",
+ }
+ ]
+ }
+ ).encode("utf-8")
+ }
+ }
+ ]
+ }
+ sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response
+
+ response = [chunk async for chunk in model.stream(messages)]
+
+ # Verify the response contains tool call events
+ assert len(response) >= 4
+ assert response[0] == {"messageStart": {"role": "assistant"}}
+
+ message_stop = next((e for e in response if "messageStop" in e), None)
+ assert message_stop is not None
+ assert message_stop["messageStop"]["stopReason"] == "tool_use"
+
+ # Find tool call events
+ tool_start = next(
+ (
+ e
+ for e in response
+ if "contentBlockStart" in e and e.get("contentBlockStart", {}).get("start", {}).get("toolUse")
+ ),
+ None,
+ )
+ tool_delta = next(
+ (
+ e
+ for e in response
+ if "contentBlockDelta" in e and e.get("contentBlockDelta", {}).get("delta", {}).get("toolUse")
+ ),
+ None,
+ )
+ tool_stop = next((e for e in response if "contentBlockStop" in e), None)
+
+ assert tool_start is not None
+ assert tool_delta is not None
+ assert tool_stop is not None
+
+ # Verify tool call data
+ tool_use_data = tool_start["contentBlockStart"]["start"]["toolUse"]
+ assert tool_use_data["toolUseId"] == "tool123"
+ assert tool_use_data["name"] == "get_weather"
+
+ @pytest.mark.asyncio
+ async def test_stream_with_partial_json(self, sagemaker_client, model, messages):
+ """Test streaming response with partial JSON chunks."""
+ # Mock the response from SageMaker with split JSON
+ mock_response = {
+ "Body": [
+ {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}},
+ {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}},
+ ]
+ }
+ sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response
+
+ response = [chunk async for chunk in model.stream(messages)]
+
+ assert len(response) == 5
+ assert response[0] == {"messageStart": {"role": "assistant"}}
+
+ # Find content events
+ content_start = next((e for e in response if "contentBlockStart" in e), None)
+ content_delta = next((e for e in response if "contentBlockDelta" in e), None)
+ content_stop = next((e for e in response if "contentBlockStop" in e), None)
+ message_stop = next((e for e in response if "messageStop" in e), None)
+
+ assert content_start is not None
+ assert content_delta is not None
+ assert content_stop is not None
+ assert message_stop is not None
+ assert message_stop["messageStop"]["stopReason"] == "end_turn"
+
+ # Verify content
+ text_delta = content_delta["contentBlockDelta"]["delta"]["text"]
+ assert text_delta == "Paris is the capital of France."
+
+ @pytest.mark.asyncio
+ async def test_stream_non_streaming(self, sagemaker_client, model, messages):
+ """Test non-streaming response."""
+ # Configure model for non-streaming
+ model.payload_config["stream"] = False
+
+ # Mock the response from SageMaker
+ mock_response = {"Body": unittest.mock.MagicMock()}
+ mock_response["Body"].read.return_value = json.dumps(
+ {
+ "choices": [
+ {
+ "message": {"content": "Paris is the capital of France.", "tool_calls": None},
+ "finish_reason": "stop",
+ }
+ ],
+ "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "prompt_tokens_details": 0},
+ }
+ ).encode("utf-8")
+
+ sagemaker_client.invoke_endpoint.return_value = mock_response
+
+ response = [chunk async for chunk in model.stream(messages)]
+
+ assert len(response) >= 6
+ assert response[0] == {"messageStart": {"role": "assistant"}}
+
+ # Find content events
+ content_start = next((e for e in response if "contentBlockStart" in e), None)
+ content_delta = next((e for e in response if "contentBlockDelta" in e), None)
+ content_stop = next((e for e in response if "contentBlockStop" in e), None)
+ message_stop = next((e for e in response if "messageStop" in e), None)
+
+ assert content_start is not None
+ assert content_delta is not None
+ assert content_stop is not None
+ assert message_stop is not None
+
+ # Verify content
+ text_delta = content_delta["contentBlockDelta"]["delta"]["text"]
+ assert text_delta == "Paris is the capital of France."
+
+ sagemaker_client.invoke_endpoint.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_stream_non_streaming_with_tool_calls(self, sagemaker_client, model, messages):
+ """Test non-streaming response with tool calls."""
+ # Configure model for non-streaming
+ model.payload_config["stream"] = False
+
+ # Mock the response from SageMaker with tool calls
+ mock_response = {"Body": unittest.mock.MagicMock()}
+ mock_response["Body"].read.return_value = json.dumps(
+ {
+ "choices": [
+ {
+ "message": {
+ "content": None,
+ "tool_calls": [
+ {
+ "id": "tool123",
+ "type": "function",
+ "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'},
+ }
+ ],
+ },
+ "finish_reason": "tool_calls",
+ }
+ ],
+ "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "prompt_tokens_details": 0},
+ }
+ ).encode("utf-8")
+
+ sagemaker_client.invoke_endpoint.return_value = mock_response
+
+ response = [chunk async for chunk in model.stream(messages)]
+
+ # Verify basic structure
+ assert len(response) >= 6
+ assert response[0] == {"messageStart": {"role": "assistant"}}
+
+ # Find tool call events
+ tool_start = next(
+ (
+ e
+ for e in response
+ if "contentBlockStart" in e and e.get("contentBlockStart", {}).get("start", {}).get("toolUse")
+ ),
+ None,
+ )
+ tool_delta = next(
+ (
+ e
+ for e in response
+ if "contentBlockDelta" in e and e.get("contentBlockDelta", {}).get("delta", {}).get("toolUse")
+ ),
+ None,
+ )
+ tool_stop = next((e for e in response if "contentBlockStop" in e), None)
+ message_stop = next((e for e in response if "messageStop" in e), None)
+
+ assert tool_start is not None
+ assert tool_delta is not None
+ assert tool_stop is not None
+ assert message_stop is not None
+
+ # Verify tool call data
+ tool_use_data = tool_start["contentBlockStart"]["start"]["toolUse"]
+ assert tool_use_data["toolUseId"] == "tool123"
+ assert tool_use_data["name"] == "get_weather"
+
+ # Verify metadata
+ metadata = next((e for e in response if "metadata" in e), None)
+ assert metadata is not None
+ usage_data = metadata["metadata"]["usage"]
+ assert usage_data["totalTokens"] == 30
+
+
+class TestDataClasses:
+ """Test suite for data classes."""
+
+ def test_usage_metadata(self):
+ """Test UsageMetadata dataclass."""
+ usage = UsageMetadata(total_tokens=100, completion_tokens=30, prompt_tokens=70, prompt_tokens_details=5)
+
+ assert usage.total_tokens == 100
+ assert usage.completion_tokens == 30
+ assert usage.prompt_tokens == 70
+ assert usage.prompt_tokens_details == 5
+
+ def test_function_call(self):
+ """Test FunctionCall dataclass."""
+ func = FunctionCall(name="get_weather", arguments='{"location": "Paris"}')
+
+ assert func.name == "get_weather"
+ assert func.arguments == '{"location": "Paris"}'
+
+ # Test initialization with kwargs
+ func2 = FunctionCall(**{"name": "get_time", "arguments": '{"timezone": "UTC"}'})
+
+ assert func2.name == "get_time"
+ assert func2.arguments == '{"timezone": "UTC"}'
+
+ def test_tool_call(self):
+ """Test ToolCall dataclass."""
+ # Create a tool call using kwargs directly
+ tool = ToolCall(
+ id="tool123", type="function", function={"name": "get_weather", "arguments": '{"location": "Paris"}'}
+ )
+
+ assert tool.id == "tool123"
+ assert tool.type == "function"
+ assert tool.function.name == "get_weather"
+ assert tool.function.arguments == '{"location": "Paris"}'
+
+ # Test initialization with kwargs
+ tool2 = ToolCall(
+ **{
+ "id": "tool456",
+ "type": "function",
+ "function": {"name": "get_time", "arguments": '{"timezone": "UTC"}'},
+ }
+ )
+
+ assert tool2.id == "tool456"
+ assert tool2.type == "function"
+ assert tool2.function.name == "get_time"
+ assert tool2.function.arguments == '{"timezone": "UTC"}'
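The `structured_output` path can be sketched the same way. This assumes the serving container honors the OpenAI-style `response_format` JSON-schema parameter the provider injects; the endpoint name is again a placeholder:

```python
import asyncio

from pydantic import BaseModel

from strands.models.sagemaker import SageMakerAIModel

class Capital(BaseModel):
    city: str
    country: str

model = SageMakerAIModel(
    endpoint_config={"endpoint_name": "my-llm-endpoint", "region_name": "us-east-1"},
    payload_config={"max_tokens": 256, "stream": False},
)

async def main() -> None:
    prompt = [{"role": "user", "content": [{"text": "What is the capital of France?"}]}]
    async for event in model.structured_output(Capital, prompt):
        if "output" in event:
            print(event["output"])  # e.g. Capital(city='Paris', country='France')

asyncio.run(main())
```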
diff --git a/tests_integ/models/test_model_sagemaker.py b/tests_integ/models/test_model_sagemaker.py
new file mode 100644
index 000000000..62362e299
--- /dev/null
+++ b/tests_integ/models/test_model_sagemaker.py
@@ -0,0 +1,76 @@
+import os
+
+import pytest
+
+import strands
+from strands import Agent
+from strands.models.sagemaker import SageMakerAIModel
+
+
+@pytest.fixture
+def model():
+ endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig(
+ endpoint_name=os.getenv("SAGEMAKER_ENDPOINT_NAME", ""), region_name="us-east-1"
+ )
+ payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024, temperature=0.7, stream=False)
+ return SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config)
+
+
+@pytest.fixture
+def tools():
+ @strands.tool
+ def tool_time(location: str) -> str:
+ """Get the current time for a location."""
+ return f"The time in {location} is 12:00 PM"
+
+ @strands.tool
+ def tool_weather(location: str) -> str:
+ """Get the current weather for a location."""
+ return f"The weather in {location} is sunny"
+
+ return [tool_time, tool_weather]
+
+
+@pytest.fixture
+def system_prompt():
+ return "You are a helpful assistant that provides concise answers."
+
+
+@pytest.fixture
+def agent(model, tools, system_prompt):
+ return Agent(model=model, tools=tools, system_prompt=system_prompt)
+
+
+@pytest.mark.skipif(
+ "SAGEMAKER_ENDPOINT_NAME" not in os.environ,
+ reason="SAGEMAKER_ENDPOINT_NAME environment variable missing",
+)
+def test_agent_with_tools(agent):
+ result = agent("What is the time and weather in New York?")
+ text = result.message["content"][0]["text"].lower()
+
+ assert "12:00" in text and "sunny" in text
+
+
+@pytest.mark.skipif(
+ "SAGEMAKER_ENDPOINT_NAME" not in os.environ,
+ reason="SAGEMAKER_ENDPOINT_NAME environment variable missing",
+)
+def test_agent_without_tools(model, system_prompt):
+ agent = Agent(model=model, system_prompt=system_prompt)
+ result = agent("Hello, how are you?")
+
+ assert result.message["content"][0]["text"]
+ assert len(result.message["content"][0]["text"]) > 0
+
+
+@pytest.mark.skipif(
+ "SAGEMAKER_ENDPOINT_NAME" not in os.environ,
+ reason="SAGEMAKER_ENDPOINT_NAME environment variable missing",
+)
+@pytest.mark.parametrize("location", ["Tokyo", "London", "Sydney"])
+def test_agent_different_locations(agent, location):
+ result = agent(f"What is the weather in {location}?")
+ text = result.message["content"][0]["text"].lower()
+
+ assert location.lower() in text and "sunny" in text
From 3f4c3a35ce14800e4852998e0c2b68f90295ffb7 Mon Sep 17 00:00:00 2001
From: mehtarac
Date: Mon, 28 Jul 2025 10:23:43 -0400
Subject: [PATCH 012/104] fix: Remove leftover print statement from sagemaker
model provider (#553)
---
src/strands/models/sagemaker.py | 2 --
1 file changed, 2 deletions(-)
diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py
index bb2db45a2..9cfe27d9e 100644
--- a/src/strands/models/sagemaker.py
+++ b/src/strands/models/sagemaker.py
@@ -274,8 +274,6 @@ def format_request(
if self.endpoint_config.get("additional_args"):
request.update(self.endpoint_config["additional_args"].__dict__)
- print(json.dumps(request["Body"], indent=2))
-
return request
@override
From bdc893bbae711c1af301e6f18901cb30814789a0 Mon Sep 17 00:00:00 2001
From: Nick Clegg
Date: Tue, 29 Jul 2025 14:41:57 -0400
Subject: [PATCH 013/104] [Feat] Update structured output error message (#563)
* Update bedrock.py
* Update anthropic.py
---
src/strands/models/anthropic.py | 2 +-
src/strands/models/bedrock.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index eb72becfd..0d734b762 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -414,7 +414,7 @@ async def structured_output(
stop_reason, messages, _, _ = event["stop"]
if stop_reason != "tool_use":
- raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
+ raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".")
content = messages["content"]
output_response: dict[str, Any] | None = None
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 679f1ea3d..cf1e4d3a9 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -584,7 +584,7 @@ async def structured_output(
stop_reason, messages, _, _ = event["stop"]
if stop_reason != "tool_use":
- raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
+ raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".")
content = messages["content"]
output_response: dict[str, Any] | None = None
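The practical effect is that callers can now see which stop_reason the model actually returned, which makes truncation failures easy to diagnose. A hedged sketch of what a caller might observe (the tiny max_tokens budget is contrived to force truncation):

```python
from pydantic import BaseModel

from strands import Agent
from strands.models import BedrockModel

class Answer(BaseModel):
    text: str

# Contrived setup: a very small max_tokens makes a "max_tokens" stop_reason likely.
agent = Agent(model=BedrockModel(max_tokens=16))

try:
    agent.structured_output(Answer, "Summarize the history of France in detail.")
except ValueError as e:
    # Before: "No valid tool use or tool use input was found in the Bedrock response."
    # After:  'Model returned stop_reason: max_tokens instead of "tool_use".'
    print(e)
```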
From 4e0e0a648c7e441ce15eacca213b7b65e982fd3b Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Tue, 29 Jul 2025 18:03:19 -0400
Subject: [PATCH 014/104] feat(mcp): retain structured content in the AgentTool
response (#528)
---
pyproject.toml | 2 +-
src/strands/models/bedrock.py | 53 +++++++++-
src/strands/tools/mcp/mcp_client.py | 49 +++++++---
src/strands/tools/mcp/mcp_types.py | 20 ++++
tests/strands/models/test_bedrock.py | 96 ++++++++++++-------
tests/strands/tools/mcp/test_mcp_client.py | 67 +++++++++++++
tests_integ/echo_server.py | 16 +++-
tests_integ/test_mcp_client.py | 77 +++++++++++++++
...cp_client_structured_content_with_hooks.py | 65 +++++++++++++
9 files changed, 389 insertions(+), 56 deletions(-)
create mode 100644 tests_integ/test_mcp_client_structured_content_with_hooks.py
diff --git a/pyproject.toml b/pyproject.toml
index 745c80e0c..095a38cb0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
"boto3>=1.26.0,<2.0.0",
"botocore>=1.29.0,<2.0.0",
"docstring_parser>=0.15,<1.0",
- "mcp>=1.8.0,<2.0.0",
+ "mcp>=1.11.0,<2.0.0",
"pydantic>=2.0.0,<3.0.0",
"typing-extensions>=4.13.2,<5.0.0",
"watchdog>=6.0.0,<7.0.0",
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index cf1e4d3a9..9b36b4244 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -17,10 +17,10 @@
from ..event_loop import streaming
from ..tools import convert_pydantic_to_tool_spec
-from ..types.content import Messages
+from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolResult, ToolSpec
from .model import Model
logger = logging.getLogger(__name__)
@@ -181,7 +181,7 @@ def format_request(
"""
return {
"modelId": self.config["model_id"],
- "messages": messages,
+ "messages": self._format_bedrock_messages(messages),
"system": [
*([{"text": system_prompt}] if system_prompt else []),
*([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
@@ -246,6 +246,53 @@ def format_request(
),
}
+ def _format_bedrock_messages(self, messages: Messages) -> Messages:
+ """Format messages for Bedrock API compatibility.
+
+ This function ensures messages conform to Bedrock's expected format by:
+ - Cleaning tool result content blocks by removing additional fields that may be
+ useful for retaining information in hooks but would cause Bedrock validation
+ exceptions when presented with unexpected fields
+ - Ensuring all message content blocks are properly formatted for the Bedrock API
+
+ Args:
+ messages: List of messages to format
+
+ Returns:
+ Messages formatted for Bedrock API compatibility
+
+ Note:
+ Bedrock will throw validation exceptions when presented with additional
+ unexpected fields in tool result blocks.
+ https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
+ """
+ cleaned_messages = []
+
+ for message in messages:
+ cleaned_content: list[ContentBlock] = []
+
+ for content_block in message["content"]:
+ if "toolResult" in content_block:
+ # Create a new content block with only the cleaned toolResult
+ tool_result: ToolResult = content_block["toolResult"]
+
+ # Keep only the required fields for Bedrock
+ cleaned_tool_result = ToolResult(
+ content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"]
+ )
+
+ cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result}
+ cleaned_content.append(cleaned_block)
+ else:
+ # Keep other content blocks as-is
+ cleaned_content.append(content_block)
+
+ # Create new message with cleaned content
+ cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
+ cleaned_messages.append(cleaned_message)
+
+ return cleaned_messages
+
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
"""Check if guardrail data contains any blocked policies.
diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py
index 4cf4e1f85..784636fd0 100644
--- a/src/strands/tools/mcp/mcp_client.py
+++ b/src/strands/tools/mcp/mcp_client.py
@@ -26,9 +26,9 @@
from ...types import PaginatedList
from ...types.exceptions import MCPClientInitializationError
from ...types.media import ImageFormat
-from ...types.tools import ToolResult, ToolResultContent, ToolResultStatus
+from ...types.tools import ToolResultContent, ToolResultStatus
from .mcp_agent_tool import MCPAgentTool
-from .mcp_types import MCPTransport
+from .mcp_types import MCPToolResult, MCPTransport
logger = logging.getLogger(__name__)
@@ -57,7 +57,8 @@ class MCPClient:
It handles the creation, initialization, and cleanup of MCP connections.
The connection runs in a background thread to avoid blocking the main application thread
- while maintaining communication with the MCP service.
+    while maintaining communication with the MCP service. When structured content is available
+    from an MCP tool, it is returned in the structuredContent field of the MCPToolResult.
"""
def __init__(self, transport_callable: Callable[[], MCPTransport]):
@@ -170,11 +171,13 @@ def call_tool_sync(
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
- ) -> ToolResult:
+ ) -> MCPToolResult:
"""Synchronously calls a tool on the MCP server.
This method calls the asynchronous call_tool method on the MCP session
- and converts the result to the ToolResult format.
+    and converts the result to the MCPToolResult format. If the MCP tool returns
+    structured content, it is surfaced in the structuredContent field of the
+    returned MCPToolResult.
Args:
tool_use_id: Unique identifier for this tool use
@@ -183,7 +186,7 @@ def call_tool_sync(
read_timeout_seconds: Optional timeout for the tool call
Returns:
- ToolResult: The result of the tool call
+ MCPToolResult: The result of the tool call
"""
self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id)
if not self._is_session_active():
@@ -205,11 +208,11 @@ async def call_tool_async(
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
- ) -> ToolResult:
+ ) -> MCPToolResult:
"""Asynchronously calls a tool on the MCP server.
This method calls the asynchronous call_tool method on the MCP session
- and converts the result to the ToolResult format.
+ and converts the result to the MCPToolResult format.
Args:
tool_use_id: Unique identifier for this tool use
@@ -218,7 +221,7 @@ async def call_tool_async(
read_timeout_seconds: Optional timeout for the tool call
Returns:
- ToolResult: The result of the tool call
+ MCPToolResult: The result of the tool call
"""
self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id)
if not self._is_session_active():
@@ -235,15 +238,27 @@ async def _call_tool_async() -> MCPCallToolResult:
logger.exception("tool execution failed")
return self._handle_tool_execution_error(tool_use_id, e)
- def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> ToolResult:
+ def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult:
"""Create error ToolResult with consistent logging."""
- return ToolResult(
+ return MCPToolResult(
status="error",
toolUseId=tool_use_id,
content=[{"text": f"Tool execution failed: {str(exception)}"}],
)
- def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> ToolResult:
+ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult:
+ """Maps MCP tool result to the agent's MCPToolResult format.
+
+ This method processes the content from the MCP tool call result and converts it to the format
+ expected by the framework.
+
+ Args:
+ tool_use_id: Unique identifier for this tool use
+ call_tool_result: The result from the MCP tool call
+
+ Returns:
+ MCPToolResult: The converted tool result
+ """
self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content))
mapped_content = [
@@ -254,7 +269,15 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes
status: ToolResultStatus = "error" if call_tool_result.isError else "success"
self._log_debug_with_thread("tool execution completed with status: %s", status)
- return ToolResult(status=status, toolUseId=tool_use_id, content=mapped_content)
+ result = MCPToolResult(
+ status=status,
+ toolUseId=tool_use_id,
+ content=mapped_content,
+ )
+ if call_tool_result.structuredContent:
+ result["structuredContent"] = call_tool_result.structuredContent
+
+ return result
async def _async_background_thread(self) -> None:
"""Asynchronous method that runs in the background thread to manage the MCP connection.
diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py
index 30defc585..5fafed5dc 100644
--- a/src/strands/tools/mcp/mcp_types.py
+++ b/src/strands/tools/mcp/mcp_types.py
@@ -1,11 +1,15 @@
"""Type definitions for MCP integration."""
from contextlib import AbstractAsyncContextManager
+from typing import Any, Dict
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.streamable_http import GetSessionIdCallback
from mcp.shared.memory import MessageStream
from mcp.shared.message import SessionMessage
+from typing_extensions import NotRequired
+
+from strands.types.tools import ToolResult
"""
MCPTransport defines the interface for MCP transport implementations. This abstracts
@@ -41,3 +45,19 @@ async def my_transport_implementation():
MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback
]
MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback]
+
+
+class MCPToolResult(ToolResult):
+ """Result of an MCP tool execution.
+
+ Extends the base ToolResult with MCP-specific structured content support.
+ The structuredContent field contains optional JSON data returned by MCP tools
+ that provides structured results beyond the standard text/image/document content.
+
+ Attributes:
+ structuredContent: Optional JSON object containing structured data returned
+ by the MCP tool. This allows MCP tools to return complex data structures
+ that can be processed programmatically by agents or other tools.
+ """
+
+ structuredContent: NotRequired[Dict[str, Any]]
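Since MCPToolResult is a TypedDict, it behaves as a plain dict at runtime and structuredContent is simply an optional key. A small illustration mirroring the unit tests below:

```python
from strands.tools.mcp.mcp_types import MCPToolResult

# Construct like a dict; structuredContent may be omitted entirely.
result = MCPToolResult(
    status="success",
    toolUseId="tool-1",
    content=[{"text": "42"}],
    structuredContent={"answer": 42},
)
assert result.get("structuredContent") == {"answer": 42}

bare = MCPToolResult(status="success", toolUseId="tool-2", content=[{"text": "no extras"}])
assert bare.get("structuredContent") is None
```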
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index 47e028cb9..0a2846adf 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -13,6 +13,7 @@
from strands.models import BedrockModel
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION
from strands.types.exceptions import ModelThrottledException
+from strands.types.tools import ToolSpec
@pytest.fixture
@@ -51,7 +52,7 @@ def model(bedrock_client, model_id):
@pytest.fixture
def messages():
- return [{"role": "user", "content": {"text": "test"}}]
+ return [{"role": "user", "content": [{"text": "test"}]}]
@pytest.fixture
@@ -90,8 +91,12 @@ def inference_config():
@pytest.fixture
-def tool_spec():
- return {"t1": 1}
+def tool_spec() -> ToolSpec:
+ return {
+ "description": "description",
+ "name": "name",
+ "inputSchema": {"key": "val"},
+ }
@pytest.fixture
@@ -750,7 +755,7 @@ async def test_stream_output_no_guardrail_redact(
@pytest.mark.asyncio
-async def test_stream_with_streaming_false(bedrock_client, alist):
+async def test_stream_with_streaming_false(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -759,8 +764,7 @@ async def test_stream_with_streaming_false(bedrock_client, alist):
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -776,7 +780,7 @@ async def test_stream_with_streaming_false(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist):
+async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {
@@ -790,8 +794,7 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist):
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -808,7 +811,7 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist):
+async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {
@@ -828,8 +831,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist):
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -847,7 +849,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_and_reasoning_no_signature(bedrock_client, alist):
+async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {
@@ -867,8 +869,7 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist):
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -884,7 +885,7 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist):
+async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -895,8 +896,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -919,7 +919,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client
@pytest.mark.asyncio
-async def test_stream_input_guardrails(bedrock_client, alist):
+async def test_stream_input_guardrails(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -937,8 +937,7 @@ async def test_stream_input_guardrails(bedrock_client, alist):
# Create model and call stream
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -970,7 +969,7 @@ async def test_stream_input_guardrails(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_output_guardrails(bedrock_client, alist):
+async def test_stream_output_guardrails(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -989,8 +988,7 @@ async def test_stream_output_guardrails(bedrock_client, alist):
}
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -1024,7 +1022,7 @@ async def test_stream_output_guardrails(bedrock_client, alist):
@pytest.mark.asyncio
-async def test_stream_output_guardrails_redacts_output(bedrock_client, alist):
+async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages):
"""Test stream method with streaming=False."""
bedrock_client.converse.return_value = {
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -1043,8 +1041,7 @@ async def test_stream_output_guardrails_redacts_output(bedrock_client, alist):
}
model = BedrockModel(model_id="test-model", streaming=False)
- request = {"modelId": "test-model"}
- response = model.stream(request)
+ response = model.stream(messages)
tru_events = await alist(response)
exp_events = [
@@ -1101,7 +1098,7 @@ async def test_structured_output(bedrock_client, model, test_output_model_cls, a
@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)")
@pytest.mark.asyncio
-async def test_add_note_on_client_error(bedrock_client, model, alist):
+async def test_add_note_on_client_error(bedrock_client, model, alist, messages):
"""Test that add_note is called on ClientError with region and model ID information."""
# Mock the client error response
error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}}
@@ -1109,13 +1106,13 @@ async def test_add_note_on_client_error(bedrock_client, model, alist):
# Call the stream method which should catch and add notes to the exception
with pytest.raises(ClientError) as err:
- await alist(model.stream({"modelId": "test-model"}))
+ await alist(model.stream(messages))
    assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"]
@pytest.mark.asyncio
-async def test_no_add_note_when_not_available(bedrock_client, model, alist):
+async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages):
"""Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception)."""
# Mock the client error response
error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}}
@@ -1123,12 +1120,12 @@ async def test_no_add_note_when_not_available(bedrock_client, model, alist):
# Call the stream method which should catch and add notes to the exception
with pytest.raises(ClientError):
- await alist(model.stream({"modelId": "test-model"}))
+ await alist(model.stream(messages))
@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)")
@pytest.mark.asyncio
-async def test_add_note_on_access_denied_exception(bedrock_client, model, alist):
+async def test_add_note_on_access_denied_exception(bedrock_client, model, alist, messages):
"""Test that add_note adds documentation link for AccessDeniedException."""
# Mock the client error response for access denied
error_response = {
@@ -1142,7 +1139,7 @@ async def test_add_note_on_access_denied_exception(bedrock_client, model, alist)
# Call the stream method which should catch and add notes to the exception
with pytest.raises(ClientError) as err:
- await alist(model.stream({"modelId": "test-model"}))
+ await alist(model.stream(messages))
+    assert err.value.__notes__ == [
+        "└ Bedrock region: us-west-2",
+        "└
@pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)")
@pytest.mark.asyncio
-async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist):
+async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist, messages):
"""Test that add_note adds documentation link for ValidationException about on-demand throughput."""
# Mock the client error response for validation exception
error_response = {
@@ -1170,7 +1167,7 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model
# Call the stream method which should catch and add notes to the exception
with pytest.raises(ClientError) as err:
- await alist(model.stream({"modelId": "test-model"}))
+ await alist(model.stream(messages))
    assert err.value.__notes__ == [
        "└ Bedrock region: us-west-2",
@@ -1202,3 +1199,32 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist):
assert "invoking model" in log_text
assert "got response from model" in log_text
assert "finished streaming response from model" in log_text
+
+
+def test_format_request_cleans_tool_result_content_blocks(model, model_id):
+ """Test that format_request cleans toolResult blocks by removing extra fields."""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "toolResult": {
+ "content": [{"text": "Tool output"}],
+ "toolUseId": "tool123",
+ "status": "success",
+ "extraField": "should be removed",
+ "mcpMetadata": {"server": "test"},
+ }
+ },
+ ],
+ }
+ ]
+
+ formatted_request = model.format_request(messages)
+
+ # Verify toolResult only contains allowed fields in the formatted request
+ tool_result = formatted_request["messages"][0]["content"][0]["toolResult"]
+ expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"}
+ assert tool_result == expected
+ assert "extraField" not in tool_result
+ assert "mcpMetadata" not in tool_result
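The cleaning step above is what lets MCP structured content ride along in conversation history without tripping Bedrock's strict ToolResultBlock validation. A hedged round-trip sketch (the default BedrockModel construction is illustrative):

```python
from strands.models import BedrockModel

model = BedrockModel()  # default model id and region; illustrative only

messages = [
    {
        "role": "user",
        "content": [
            {
                "toolResult": {
                    "toolUseId": "tool123",
                    "status": "success",
                    "content": [{"text": "Tool output"}],
                    # Extra fields such as MCP structuredContent are useful in hooks...
                    "structuredContent": {"result": 42},
                }
            }
        ],
    }
]

request = model.format_request(messages)
# ...but are stripped before the request reaches Bedrock.
tool_result = request["messages"][0]["content"][0]["toolResult"]
assert set(tool_result) == {"content", "toolUseId", "status"}
```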
diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py
index 6a2fdd00c..3d3792c71 100644
--- a/tests/strands/tools/mcp/test_mcp_client.py
+++ b/tests/strands/tools/mcp/test_mcp_client.py
@@ -8,6 +8,7 @@
from mcp.types import Tool as MCPTool
from strands.tools.mcp import MCPClient
+from strands.tools.mcp.mcp_types import MCPToolResult
from strands.types.exceptions import MCPClientInitializationError
@@ -129,6 +130,8 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_
assert result["toolUseId"] == "test-123"
assert len(result["content"]) == 1
assert result["content"][0]["text"] == "Test message"
+ # No structured content should be present when not provided by MCP
+ assert result.get("structuredContent") is None
def test_call_tool_sync_session_not_active():
@@ -139,6 +142,31 @@ def test_call_tool_sync_session_not_active():
client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
+def test_call_tool_sync_with_structured_content(mock_transport, mock_session):
+ """Test that call_tool_sync correctly handles structured content."""
+ mock_content = MCPTextContent(type="text", text="Test message")
+ structured_content = {"result": 42, "status": "completed"}
+ mock_session.call_tool.return_value = MCPCallToolResult(
+ isError=False, content=[mock_content], structuredContent=structured_content
+ )
+
+ with MCPClient(mock_transport["transport_callable"]) as client:
+ result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
+
+ mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None)
+
+ assert result["status"] == "success"
+ assert result["toolUseId"] == "test-123"
+ # Content should only contain the text content, not the structured content
+ assert len(result["content"]) == 1
+ assert result["content"][0]["text"] == "Test message"
+ # Structured content should be in its own field
+ assert "structuredContent" in result
+ assert result["structuredContent"] == structured_content
+ assert result["structuredContent"]["result"] == 42
+ assert result["structuredContent"]["status"] == "completed"
+
+
def test_call_tool_sync_exception(mock_transport, mock_session):
"""Test that call_tool_sync correctly handles exceptions."""
mock_session.call_tool.side_effect = Exception("Test exception")
@@ -312,6 +340,45 @@ def test_enter_with_initialization_exception(mock_transport):
client.start()
+def test_mcp_tool_result_type():
+ """Test that MCPToolResult extends ToolResult correctly."""
+ # Test basic ToolResult functionality
+ result = MCPToolResult(status="success", toolUseId="test-123", content=[{"text": "Test message"}])
+
+ assert result["status"] == "success"
+ assert result["toolUseId"] == "test-123"
+ assert result["content"][0]["text"] == "Test message"
+
+ # Test that structuredContent is optional
+ assert "structuredContent" not in result or result.get("structuredContent") is None
+
+ # Test with structuredContent
+ result_with_structured = MCPToolResult(
+ status="success", toolUseId="test-456", content=[{"text": "Test message"}], structuredContent={"key": "value"}
+ )
+
+ assert result_with_structured["structuredContent"] == {"key": "value"}
+
+
+def test_call_tool_sync_without_structured_content(mock_transport, mock_session):
+ """Test that call_tool_sync works correctly when no structured content is provided."""
+ mock_content = MCPTextContent(type="text", text="Test message")
+ mock_session.call_tool.return_value = MCPCallToolResult(
+ isError=False,
+ content=[mock_content], # No structuredContent
+ )
+
+ with MCPClient(mock_transport["transport_callable"]) as client:
+ result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
+
+ assert result["status"] == "success"
+ assert result["toolUseId"] == "test-123"
+ assert len(result["content"]) == 1
+ assert result["content"][0]["text"] == "Test message"
+ # structuredContent should be None when not provided by MCP
+ assert result.get("structuredContent") is None
+
+
def test_exception_when_future_not_running():
"""Test exception handling when the future is not running."""
    # Create a client with a mock transport
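The unit tests above pin down the contract for the new `structuredContent` field. As a consumer-side sketch (hedged: `transport_callable` is an assumed transport factory, not part of this patch), reading the field defensively looks like:

```python
from strands.tools.mcp import MCPClient

# transport_callable is an assumed MCP transport factory (e.g. stdio).
with MCPClient(transport_callable) as client:
    result = client.call_tool_sync(
        tool_use_id="example-1",
        name="echo_with_structured_content",
        arguments={"to_echo": "hi"},
    )
    # structuredContent is optional on MCPToolResult; fall back to the
    # plain text content blocks when a tool does not provide it.
    structured = result.get("structuredContent")
    if structured is not None:
        print(structured)
    else:
        print(result["content"][0]["text"])
```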
diff --git a/tests_integ/echo_server.py b/tests_integ/echo_server.py
index d309607a8..52223792c 100644
--- a/tests_integ/echo_server.py
+++ b/tests_integ/echo_server.py
@@ -2,7 +2,7 @@
Echo Server for MCP Integration Testing
This module implements a simple echo server using the Model Context Protocol (MCP).
-It provides a basic tool that echoes back any input string, which is useful for
+It provides basic tools that echo back input strings and structured content, which is useful for
testing the MCP communication flow and validating that messages are properly
transmitted between the client and server.
@@ -15,6 +15,8 @@
$ python echo_server.py
"""
+from typing import Any, Dict
+
from mcp.server import FastMCP
@@ -22,16 +24,22 @@ def start_echo_server():
"""
Initialize and start the MCP echo server.
- Creates a FastMCP server instance with a single 'echo' tool that returns
- any input string back to the caller. The server uses stdio transport
+ Creates a FastMCP server instance with tools that return
+ input strings and structured content back to the caller. The server uses stdio transport
for communication.
+
"""
mcp = FastMCP("Echo Server")
- @mcp.tool(description="Echos response back to the user")
+    @mcp.tool(description="Echoes response back to the user", structured_output=False)
def echo(to_echo: str) -> str:
return to_echo
+    # FastMCP automatically constructs the structured output schema from the method signature
+    @mcp.tool(description="Echoes response back with structured content", structured_output=True)
+ def echo_with_structured_content(to_echo: str) -> Dict[str, Any]:
+ return {"echoed": to_echo}
+
mcp.run(transport="stdio")
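For reference when reading the assertions below: with `structured_output=True`, FastMCP derives the output schema from the return annotation and, as an implementation detail rather than MCP spec behavior, nests the value under a `"result"` key. A sketch of the resulting `CallToolResult` (values illustrative; field names as used in the unit tests above):

```python
from mcp.types import CallToolResult, TextContent

# Illustrative shape only: what echo_with_structured_content("hi") yields.
result = CallToolResult(
    isError=False,
    content=[TextContent(type="text", text='{"echoed": "hi"}')],
    structuredContent={"result": {"echoed": "hi"}},
)
```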
diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py
index 9163f625d..ebd4f5896 100644
--- a/tests_integ/test_mcp_client.py
+++ b/tests_integ/test_mcp_client.py
@@ -1,4 +1,5 @@
import base64
+import json
import os
import threading
import time
@@ -87,6 +88,24 @@ def test_mcp_client():
]
)
+ tool_use_id = "test-structured-content-123"
+ result = stdio_mcp_client.call_tool_sync(
+ tool_use_id=tool_use_id,
+ name="echo_with_structured_content",
+ arguments={"to_echo": "STRUCTURED_DATA_TEST"},
+ )
+
+ # With the new MCPToolResult, structured content is in its own field
+ assert "structuredContent" in result
+ assert result["structuredContent"]["result"] == {"echoed": "STRUCTURED_DATA_TEST"}
+
+ # Verify the result is an MCPToolResult (at runtime it's just a dict, but type-wise it should be MCPToolResult)
+ assert result["status"] == "success"
+ assert result["toolUseId"] == tool_use_id
+
+ assert len(result["content"]) == 1
+ assert json.loads(result["content"][0]["text"]) == {"echoed": "STRUCTURED_DATA_TEST"}
+
def test_can_reuse_mcp_client():
stdio_mcp_client = MCPClient(
@@ -103,6 +122,64 @@ def test_can_reuse_mcp_client():
assert any([block["name"] == "echo" for block in tool_use_content_blocks])
+@pytest.mark.asyncio
+async def test_mcp_client_async_structured_content():
+ """Test that async MCP client calls properly handle structured content.
+
+ This test demonstrates how tools configure structured output: FastMCP automatically
+    constructs the structured output schema from the method signature when structured_output=True
+    is set in the @mcp.tool decorator. The return type annotation defines the structure
+    that appears in the structuredContent field.
+ """
+ stdio_mcp_client = MCPClient(
+ lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
+ )
+
+ with stdio_mcp_client:
+ tool_use_id = "test-async-structured-content-456"
+ result = await stdio_mcp_client.call_tool_async(
+ tool_use_id=tool_use_id,
+ name="echo_with_structured_content",
+ arguments={"to_echo": "ASYNC_STRUCTURED_TEST"},
+ )
+
+ # Verify structured content is in its own field
+ assert "structuredContent" in result
+ # "result" nesting is not part of the MCP Structured Content specification,
+ # but rather a FastMCP implementation detail
+ assert result["structuredContent"]["result"] == {"echoed": "ASYNC_STRUCTURED_TEST"}
+
+ # Verify basic MCPToolResult structure
+ assert result["status"] in ["success", "error"]
+ assert result["toolUseId"] == tool_use_id
+
+ assert len(result["content"]) == 1
+ assert json.loads(result["content"][0]["text"]) == {"echoed": "ASYNC_STRUCTURED_TEST"}
+
+
+def test_mcp_client_without_structured_content():
+ """Test that MCP client works correctly when tools don't return structured content."""
+ stdio_mcp_client = MCPClient(
+ lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
+ )
+
+ with stdio_mcp_client:
+ tool_use_id = "test-no-structured-content-789"
+ result = stdio_mcp_client.call_tool_sync(
+ tool_use_id=tool_use_id,
+ name="echo", # This tool doesn't return structured content
+ arguments={"to_echo": "SIMPLE_ECHO_TEST"},
+ )
+
+ # Verify no structured content when tool doesn't provide it
+ assert result.get("structuredContent") is None
+
+ # Verify basic result structure
+ assert result["status"] == "success"
+ assert result["toolUseId"] == tool_use_id
+ assert result["content"] == [{"text": "SIMPLE_ECHO_TEST"}]
+
+
@pytest.mark.skipif(
condition=os.environ.get("GITHUB_ACTIONS") == "true",
reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",
diff --git a/tests_integ/test_mcp_client_structured_content_with_hooks.py b/tests_integ/test_mcp_client_structured_content_with_hooks.py
new file mode 100644
index 000000000..ca2468c48
--- /dev/null
+++ b/tests_integ/test_mcp_client_structured_content_with_hooks.py
@@ -0,0 +1,65 @@
+"""Integration test demonstrating hooks system with MCP client structured content tool.
+
+This test shows how to use the hooks system to capture and inspect tool invocation
+results, specifically testing the echo_with_structured_content tool from echo_server.
+"""
+
+import json
+
+from mcp import StdioServerParameters, stdio_client
+
+from strands import Agent
+from strands.experimental.hooks import AfterToolInvocationEvent
+from strands.hooks import HookProvider, HookRegistry
+from strands.tools.mcp.mcp_client import MCPClient
+
+
+class StructuredContentHookProvider(HookProvider):
+ """Hook provider that captures structured content tool results."""
+
+ def __init__(self):
+ self.captured_result = None
+
+ def register_hooks(self, registry: HookRegistry) -> None:
+ """Register callback for after tool invocation events."""
+ registry.add_callback(AfterToolInvocationEvent, self.on_after_tool_invocation)
+
+ def on_after_tool_invocation(self, event: AfterToolInvocationEvent) -> None:
+ """Capture structured content tool results."""
+ if event.tool_use["name"] == "echo_with_structured_content":
+ self.captured_result = event.result
+
+
+def test_mcp_client_hooks_structured_content():
+ """Test using hooks to inspect echo_with_structured_content tool result."""
+ # Create hook provider to capture tool result
+ hook_provider = StructuredContentHookProvider()
+
+ # Set up MCP client for echo server
+ stdio_mcp_client = MCPClient(
+ lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
+ )
+
+ with stdio_mcp_client:
+ # Create agent with MCP tools and hook provider
+ agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[hook_provider])
+
+ # Test structured content functionality
+ test_data = "HOOKS_TEST_DATA"
+ agent(f"Use the echo_with_structured_content tool to echo: {test_data}")
+
+ # Verify hook captured the tool result
+ assert hook_provider.captured_result is not None
+ result = hook_provider.captured_result
+
+ # Verify basic result structure
+ assert result["status"] == "success"
+ assert len(result["content"]) == 1
+
+ # Verify structured content is present and correct
+ assert "structuredContent" in result
+ assert result["structuredContent"]["result"] == {"echoed": test_data}
+
+ # Verify text content matches structured content
+ text_content = json.loads(result["content"][0]["text"])
+ assert text_content == {"echoed": test_data}
From b13c5c5492e7745acb86d23eb215acdce0120361 Mon Sep 17 00:00:00 2001
From: Ketan Suhaas Saichandran <55935983+Ketansuhaas@users.noreply.github.com>
Date: Wed, 30 Jul 2025 08:59:29 -0400
Subject: [PATCH 015/104] feat(mcp): Add list_prompts, get_prompt methods
(#160)
Co-authored-by: ketan-clairyon
Co-authored-by: Dean Schmigelski
---
src/strands/tools/mcp/mcp_client.py | 49 +++++++++++++
tests/strands/tools/mcp/test_mcp_client.py | 62 ++++++++++++++++
tests_integ/test_mcp_client.py | 83 +++++++++++++++++++---
3 files changed, 184 insertions(+), 10 deletions(-)
diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py
index 784636fd0..8c21baa4a 100644
--- a/src/strands/tools/mcp/mcp_client.py
+++ b/src/strands/tools/mcp/mcp_client.py
@@ -20,6 +20,7 @@
from mcp import ClientSession, ListToolsResult
from mcp.types import CallToolResult as MCPCallToolResult
+from mcp.types import GetPromptResult, ListPromptsResult
from mcp.types import ImageContent as MCPImageContent
from mcp.types import TextContent as MCPTextContent
@@ -165,6 +166,54 @@ async def _list_tools_async() -> ListToolsResult:
self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools))
return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor)
+ def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromptsResult:
+ """Synchronously retrieves the list of available prompts from the MCP server.
+
+ This method calls the asynchronous list_prompts method on the MCP session
+ and returns the raw ListPromptsResult with pagination support.
+
+ Args:
+ pagination_token: Optional token for pagination
+
+ Returns:
+ ListPromptsResult: The raw MCP response containing prompts and pagination info
+ """
+ self._log_debug_with_thread("listing MCP prompts synchronously")
+ if not self._is_session_active():
+ raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
+
+ async def _list_prompts_async() -> ListPromptsResult:
+ return await self._background_thread_session.list_prompts(cursor=pagination_token)
+
+ list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result()
+ self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts))
+ for prompt in list_prompts_result.prompts:
+ self._log_debug_with_thread(prompt.name)
+
+ return list_prompts_result
+
+ def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResult:
+ """Synchronously retrieves a prompt from the MCP server.
+
+ Args:
+ prompt_id: The ID of the prompt to retrieve
+ args: Optional arguments to pass to the prompt
+
+ Returns:
+ GetPromptResult: The prompt response from the MCP server
+ """
+ self._log_debug_with_thread("getting MCP prompt synchronously")
+ if not self._is_session_active():
+ raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
+
+ async def _get_prompt_async() -> GetPromptResult:
+ return await self._background_thread_session.get_prompt(prompt_id, arguments=args)
+
+ get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result()
+ self._log_debug_with_thread("received prompt from MCP server")
+
+ return get_prompt_result
+
def call_tool_sync(
self,
tool_use_id: str,
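A usage sketch for the two new methods (hedged: the `transport_callable` is assumed; `pagination_token`, `nextCursor`, and `greeting_prompt` are as defined in this patch and its tests):

```python
from strands.tools.mcp import MCPClient

with MCPClient(transport_callable) as client:
    # Walk all prompt pages via the pagination token.
    token = None
    while True:
        page = client.list_prompts_sync(pagination_token=token)
        for prompt in page.prompts:
            print(prompt.name, prompt.description)
        if page.nextCursor is None:
            break
        token = page.nextCursor

    # Render a single prompt with arguments.
    rendered = client.get_prompt_sync("greeting_prompt", {"name": "Alice"})
    print(rendered.messages[0].content.text)
```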
diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py
index 3d3792c71..bd88382cd 100644
--- a/tests/strands/tools/mcp/test_mcp_client.py
+++ b/tests/strands/tools/mcp/test_mcp_client.py
@@ -4,6 +4,7 @@
import pytest
from mcp import ListToolsResult
from mcp.types import CallToolResult as MCPCallToolResult
+from mcp.types import GetPromptResult, ListPromptsResult, Prompt, PromptMessage
from mcp.types import TextContent as MCPTextContent
from mcp.types import Tool as MCPTool
@@ -404,3 +405,64 @@ def test_exception_when_future_not_running():
# Verify that set_exception was not called since the future was not running
mock_future.set_exception.assert_not_called()
+
+
+# Prompt Tests - Sync Methods
+
+
+def test_list_prompts_sync(mock_transport, mock_session):
+ """Test that list_prompts_sync correctly retrieves prompts."""
+ mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1")
+ mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt])
+
+ with MCPClient(mock_transport["transport_callable"]) as client:
+ result = client.list_prompts_sync()
+
+ mock_session.list_prompts.assert_called_once_with(cursor=None)
+ assert len(result.prompts) == 1
+ assert result.prompts[0].name == "test_prompt"
+ assert result.nextCursor is None
+
+
+def test_list_prompts_sync_with_pagination_token(mock_transport, mock_session):
+ """Test that list_prompts_sync correctly passes pagination token and returns next cursor."""
+ mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1")
+ mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt], nextCursor="next_page_token")
+
+ with MCPClient(mock_transport["transport_callable"]) as client:
+ result = client.list_prompts_sync(pagination_token="current_page_token")
+
+ mock_session.list_prompts.assert_called_once_with(cursor="current_page_token")
+ assert len(result.prompts) == 1
+ assert result.prompts[0].name == "test_prompt"
+ assert result.nextCursor == "next_page_token"
+
+
+def test_list_prompts_sync_session_not_active():
+ """Test that list_prompts_sync raises an error when session is not active."""
+ client = MCPClient(MagicMock())
+
+ with pytest.raises(MCPClientInitializationError, match="client session is not running"):
+ client.list_prompts_sync()
+
+
+def test_get_prompt_sync(mock_transport, mock_session):
+ """Test that get_prompt_sync correctly retrieves a prompt."""
+ mock_message = PromptMessage(role="user", content=MCPTextContent(type="text", text="This is a test prompt"))
+ mock_session.get_prompt.return_value = GetPromptResult(messages=[mock_message])
+
+ with MCPClient(mock_transport["transport_callable"]) as client:
+ result = client.get_prompt_sync("test_prompt_id", {"key": "value"})
+
+ mock_session.get_prompt.assert_called_once_with("test_prompt_id", arguments={"key": "value"})
+ assert len(result.messages) == 1
+ assert result.messages[0].role == "user"
+ assert result.messages[0].content.text == "This is a test prompt"
+
+
+def test_get_prompt_sync_session_not_active():
+ """Test that get_prompt_sync raises an error when session is not active."""
+ client = MCPClient(MagicMock())
+
+ with pytest.raises(MCPClientInitializationError, match="client session is not running"):
+ client.get_prompt_sync("test_prompt_id", {})
diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py
index ebd4f5896..3de249435 100644
--- a/tests_integ/test_mcp_client.py
+++ b/tests_integ/test_mcp_client.py
@@ -18,18 +18,17 @@
from strands.types.tools import ToolUse
-def start_calculator_server(transport: Literal["sse", "streamable-http"], port=int):
+def start_comprehensive_mcp_server(transport: Literal["sse", "streamable-http"], port: int):
"""
- Initialize and start an MCP calculator server for integration testing.
+ Initialize and start a comprehensive MCP server for integration testing.
- This function creates a FastMCP server instance that provides a simple
- calculator tool for performing addition operations. The server uses
- Server-Sent Events (SSE) transport for communication, making it accessible
- over HTTP.
+ This function creates a FastMCP server instance that provides tools, prompts,
+ and resources all in one server for comprehensive testing. The server uses
+ Server-Sent Events (SSE) or streamable HTTP transport for communication.
"""
from mcp.server import FastMCP
- mcp = FastMCP("Calculator Server", port=port)
+ mcp = FastMCP("Comprehensive MCP Server", port=port)
@mcp.tool(description="Calculator tool which performs calculations")
def calculator(x: int, y: int) -> int:
@@ -44,6 +43,15 @@ def generate_custom_image() -> MCPImageContent:
except Exception as e:
print("Error while generating custom image: {}".format(e))
+ # Prompts
+ @mcp.prompt(description="A greeting prompt template")
+ def greeting_prompt(name: str = "World") -> str:
+ return f"Hello, {name}! How are you today?"
+
+ @mcp.prompt(description="A math problem prompt template")
+ def math_prompt(operation: str = "addition", difficulty: str = "easy") -> str:
+ return f"Create a {difficulty} {operation} math problem and solve it step by step."
+
mcp.run(transport=transport)
@@ -58,8 +66,9 @@ def test_mcp_client():
{'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]}
""" # noqa: E501
+ # Start comprehensive server with tools, prompts, and resources
server_thread = threading.Thread(
- target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True
+ target=start_comprehensive_mcp_server, kwargs={"transport": "sse", "port": 8000}, daemon=True
)
server_thread.start()
time.sleep(2) # wait for server to startup completely
@@ -68,8 +77,14 @@ def test_mcp_client():
stdio_mcp_client = MCPClient(
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
)
+
with sse_mcp_client, stdio_mcp_client:
- agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync())
+ # Test Tools functionality
+ sse_tools = sse_mcp_client.list_tools_sync()
+ stdio_tools = stdio_mcp_client.list_tools_sync()
+ all_tools = sse_tools + stdio_tools
+
+ agent = Agent(tools=all_tools)
agent("add 1 and 2, then echo the result back to me")
tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
@@ -88,6 +103,43 @@ def test_mcp_client():
]
)
+ # Test Prompts functionality
+ prompts_result = sse_mcp_client.list_prompts_sync()
+ assert len(prompts_result.prompts) >= 2 # We expect at least greeting and math prompts
+
+ prompt_names = [prompt.name for prompt in prompts_result.prompts]
+ assert "greeting_prompt" in prompt_names
+ assert "math_prompt" in prompt_names
+
+ # Test get_prompt_sync with greeting prompt
+ greeting_result = sse_mcp_client.get_prompt_sync("greeting_prompt", {"name": "Alice"})
+ assert len(greeting_result.messages) > 0
+ prompt_text = greeting_result.messages[0].content.text
+ assert "Hello, Alice!" in prompt_text
+ assert "How are you today?" in prompt_text
+
+ # Test get_prompt_sync with math prompt
+ math_result = sse_mcp_client.get_prompt_sync(
+ "math_prompt", {"operation": "multiplication", "difficulty": "medium"}
+ )
+ assert len(math_result.messages) > 0
+ math_text = math_result.messages[0].content.text
+ assert "multiplication" in math_text
+ assert "medium" in math_text
+ assert "step by step" in math_text
+
+ # Test pagination support for prompts
+ prompts_with_token = sse_mcp_client.list_prompts_sync(pagination_token=None)
+ assert len(prompts_with_token.prompts) >= 0
+
+ # Test pagination support for tools (existing functionality)
+ tools_with_token = sse_mcp_client.list_tools_sync(pagination_token=None)
+ assert len(tools_with_token) >= 0
+
+ # TODO: Add resources testing when resources are implemented
+ # resources_result = sse_mcp_client.list_resources_sync()
+ # assert len(resources_result.resources) >= 0
+
tool_use_id = "test-structured-content-123"
result = stdio_mcp_client.call_tool_sync(
tool_use_id=tool_use_id,
@@ -185,8 +237,9 @@ def test_mcp_client_without_structured_content():
reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",
)
def test_streamable_http_mcp_client():
+ """Test comprehensive MCP client with streamable HTTP transport."""
server_thread = threading.Thread(
- target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
+ target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
)
server_thread.start()
time.sleep(2) # wait for server to startup completely
@@ -196,12 +249,22 @@ def transport_callback() -> MCPTransport:
streamable_http_client = MCPClient(transport_callback)
with streamable_http_client:
+ # Test tools
agent = Agent(tools=streamable_http_client.list_tools_sync())
agent("add 1 and 2 using a calculator")
tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
assert any([block["name"] == "calculator" for block in tool_use_content_blocks])
+ # Test prompts
+ prompts_result = streamable_http_client.list_prompts_sync()
+ assert len(prompts_result.prompts) >= 2
+
+ greeting_result = streamable_http_client.get_prompt_sync("greeting_prompt", {"name": "Charlie"})
+ assert len(greeting_result.messages) > 0
+ prompt_text = greeting_result.messages[0].content.text
+ assert "Hello, Charlie!" in prompt_text
+
def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]:
return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block]
From 3d526f2e254d38bb83b8ec85af56e79e4e1fe33f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=81=BF=E3=81=AE=E3=82=8B=E3=82=93?=
<74597894+minorun365@users.noreply.github.com>
Date: Thu, 31 Jul 2025 23:40:25 +0900
Subject: [PATCH 016/104] fix(deps): pin a2a-sdk>=0.2.16 to resolve #572 (#581)
Co-authored-by: Jeremiah
---
pyproject.toml | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 095a38cb0..cdf68e01f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,6 +96,7 @@ sagemaker = [
]
a2a = [
+ "a2a-sdk>=0.2.16,<1.0.0",
"a2a-sdk[sql]>=0.2.11,<1.0.0",
"uvicorn>=0.34.2,<1.0.0",
"httpx>=0.28.1,<1.0.0",
@@ -321,4 +322,4 @@ style = [
["instruction", ""],
["text", ""],
["disabled", "fg:#858585 italic"]
-]
\ No newline at end of file
+]
From b56a4ff32e93dd74a10c8895cd68528091e88f1b Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 1 Aug 2025 09:42:35 -0400
Subject: [PATCH 017/104] chore: pin a2a to a minor version while it is still
in beta (#586)
---
pyproject.toml | 6 +++---
src/strands/multiagent/a2a/executor.py | 2 +-
tests/strands/multiagent/a2a/test_executor.py | 16 ++++++++--------
tests/strands/multiagent/a2a/test_server.py | 4 ++--
4 files changed, 14 insertions(+), 14 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index cdf68e01f..586a956af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,8 +96,8 @@ sagemaker = [
]
a2a = [
- "a2a-sdk>=0.2.16,<1.0.0",
- "a2a-sdk[sql]>=0.2.11,<1.0.0",
+ "a2a-sdk>=0.3.0,<0.4.0",
+ "a2a-sdk[sql]>=0.3.0,<0.4.0",
"uvicorn>=0.34.2,<1.0.0",
"httpx>=0.28.1,<1.0.0",
"fastapi>=0.115.12,<1.0.0",
@@ -143,7 +143,7 @@ all = [
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
# a2a
- "a2a-sdk[sql]>=0.2.11,<1.0.0",
+ "a2a-sdk[sql]>=0.3.0,<0.4.0",
"uvicorn>=0.34.2,<1.0.0",
"httpx>=0.28.1,<1.0.0",
"fastapi>=0.115.12,<1.0.0",
diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py
index d65c64aff..5bf9cbfe9 100644
--- a/src/strands/multiagent/a2a/executor.py
+++ b/src/strands/multiagent/a2a/executor.py
@@ -61,7 +61,7 @@ async def execute(
task = new_task(context.message) # type: ignore
await event_queue.enqueue_event(task)
- updater = TaskUpdater(event_queue, task.id, task.contextId)
+ updater = TaskUpdater(event_queue, task.id, task.context_id)
try:
await self._execute_streaming(context, updater)
diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py
index a956cb769..77645fc73 100644
--- a/tests/strands/multiagent/a2a/test_executor.py
+++ b/tests/strands/multiagent/a2a/test_executor.py
@@ -36,7 +36,7 @@ async def mock_stream(user_input):
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
await executor.execute(mock_request_context, mock_event_queue)
@@ -65,7 +65,7 @@ async def mock_stream(user_input):
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
await executor.execute(mock_request_context, mock_event_queue)
@@ -95,7 +95,7 @@ async def mock_stream(user_input):
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
await executor.execute(mock_request_context, mock_event_queue)
@@ -125,7 +125,7 @@ async def mock_stream(user_input):
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
await executor.execute(mock_request_context, mock_event_queue)
@@ -156,7 +156,7 @@ async def mock_stream(user_input):
mock_request_context.current_task = None
with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task:
- mock_new_task.return_value = MagicMock(id="new-task-id", contextId="new-context-id")
+ mock_new_task.return_value = MagicMock(id="new-task-id", context_id="new-context-id")
await executor.execute(mock_request_context, mock_event_queue)
@@ -180,7 +180,7 @@ async def test_execute_streaming_mode_handles_agent_exception(
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
with pytest.raises(ServerError):
@@ -210,7 +210,7 @@ async def test_handle_agent_result_with_none_result(mock_strands_agent, mock_req
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
# Mock TaskUpdater
@@ -235,7 +235,7 @@ async def test_handle_agent_result_with_result_but_no_message(
# Mock the task creation
mock_task = MagicMock()
mock_task.id = "test-task-id"
- mock_task.contextId = "test-context-id"
+ mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
# Mock TaskUpdater
diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py
index fc76b5f1d..a3b47581c 100644
--- a/tests/strands/multiagent/a2a/test_server.py
+++ b/tests/strands/multiagent/a2a/test_server.py
@@ -87,8 +87,8 @@ def test_public_agent_card(mock_strands_agent):
assert card.description == "A test agent for unit testing"
assert card.url == "http://0.0.0.0:9000/"
assert card.version == "0.0.1"
- assert card.defaultInputModes == ["text"]
- assert card.defaultOutputModes == ["text"]
+ assert card.default_input_modes == ["text"]
+ assert card.default_output_modes == ["text"]
assert card.skills == []
assert card.capabilities == a2a_agent.capabilities
From 8b1de4d4cc4f8adc5386bb1a134aabf96e698cdd Mon Sep 17 00:00:00 2001
From: Laith Al-Saadoon <9553966+theagenticguy@users.noreply.github.com>
Date: Fri, 1 Aug 2025 09:23:25 -0500
Subject: [PATCH 018/104] fix: uses new a2a snake_case for lints to pass (#591)
---
src/strands/models/anthropic.py | 2 +-
src/strands/models/bedrock.py | 2 +-
src/strands/session/file_session_manager.py | 3 ++-
src/strands/session/s3_session_manager.py | 3 ++-
4 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index 0d734b762..975fca3e9 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -414,7 +414,7 @@ async def structured_output(
stop_reason, messages, _, _ = event["stop"]
if stop_reason != "tool_use":
- raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".")
+ raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
content = messages["content"]
output_response: dict[str, Any] | None = None
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 9b36b4244..4ea1453a4 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -631,7 +631,7 @@ async def structured_output(
stop_reason, messages, _, _ = event["stop"]
if stop_reason != "tool_use":
- raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".")
+ raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
content = messages["content"]
output_response: dict[str, Any] | None = None
diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py
index b32cb00e6..fec2f0761 100644
--- a/src/strands/session/file_session_manager.py
+++ b/src/strands/session/file_session_manager.py
@@ -23,6 +23,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository):
"""File-based session manager for local filesystem storage.
Creates the following filesystem structure for the session storage:
+ ```bash
    /<storage_dir>/
    └── session_<session_id>/
        ├── session.json # Session metadata
@@ -32,7 +33,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository):
            └── messages/
                ├── message_<message_id>.json
                └── message_<message_id>.json
-
+ ```
"""
def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any):
diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py
index 8f8423828..0cc0a68c1 100644
--- a/src/strands/session/s3_session_manager.py
+++ b/src/strands/session/s3_session_manager.py
@@ -24,6 +24,7 @@ class S3SessionManager(RepositorySessionManager, SessionRepository):
"""S3-based session manager for cloud storage.
Creates the following filesystem structure for the session storage:
+ ```bash
    /<bucket>/<prefix>/
    └── session_<session_id>/
        ├── session.json # Session metadata
@@ -33,7 +34,7 @@ class S3SessionManager(RepositorySessionManager, SessionRepository):
            └── messages/
                ├── message_<message_id>.json
                └── message_<message_id>.json
-
+ ```
"""
def __init__(
From c85464c45715a9d2ef3f9377f59f9e970ee81cf9 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 1 Aug 2025 10:37:17 -0400
Subject: [PATCH 019/104] =?UTF-8?q?fix(event=5Floop):=20raise=20dedicated?=
=?UTF-8?q?=20exception=20when=20encountering=20max=20toke=E2=80=A6=20(#57?=
=?UTF-8?q?6)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* fix(event_loop): raise dedicated exception when encountering max tokens stop reason
* fix: update integ tests
* fix: rename exception message, add to exception, move earlier in cycle
* Update tests_integ/test_max_tokens_reached.py
Co-authored-by: Nick Clegg
* Update tests_integ/test_max_tokens_reached.py
Co-authored-by: Nick Clegg
* linting
---------
Co-authored-by: Nick Clegg
---
src/strands/event_loop/event_loop.py | 26 ++++++++++-
src/strands/types/exceptions.py | 21 +++++++++
tests/strands/event_loop/test_event_loop.py | 52 ++++++++++++++++++++-
tests_integ/test_max_tokens_reached.py | 20 ++++++++
4 files changed, 116 insertions(+), 3 deletions(-)
create mode 100644 tests_integ/test_max_tokens_reached.py
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index ffcb6a5c9..ae21d4c6d 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -28,7 +28,12 @@
from ..telemetry.tracer import get_tracer
from ..tools.executor import run_tools, validate_and_prepare_tools
from ..types.content import Message
-from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
+from ..types.exceptions import (
+ ContextWindowOverflowException,
+ EventLoopException,
+ MaxTokensReachedException,
+ ModelThrottledException,
+)
from ..types.streaming import Metrics, StopReason
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
from .streaming import stream_messages
@@ -187,6 +192,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
raise e
try:
+ if stop_reason == "max_tokens":
+ """
+ Handle max_tokens limit reached by the model.
+
+ When the model reaches its maximum token limit, this represents a potentially unrecoverable
+ state where the model's response was truncated. By default, Strands fails hard with an
+ MaxTokensReachedException to maintain consistency with other failure types.
+ """
+ raise MaxTokensReachedException(
+ message=(
+ "Agent has reached an unrecoverable state due to max_tokens limit. "
+ "For more information see: "
+ "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
+ ),
+ incomplete_message=message,
+ )
# Add message in trace and mark the end of the stream messages trace
stream_trace.add_message(message)
stream_trace.end()
@@ -231,7 +252,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
# Don't yield or log the exception - we already did it when we
# raised the exception and we don't need that duplication.
raise
- except ContextWindowOverflowException as e:
+ except (ContextWindowOverflowException, MaxTokensReachedException) as e:
+ # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException
if cycle_span:
tracer.end_span_with_error(cycle_span, str(e), e)
raise e
diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py
index 4bd3fd88e..71ea28b9f 100644
--- a/src/strands/types/exceptions.py
+++ b/src/strands/types/exceptions.py
@@ -2,6 +2,8 @@
from typing import Any
+from strands.types.content import Message
+
class EventLoopException(Exception):
"""Exception raised by the event loop."""
@@ -18,6 +20,25 @@ def __init__(self, original_exception: Exception, request_state: Any = None) ->
super().__init__(str(original_exception))
+class MaxTokensReachedException(Exception):
+ """Exception raised when the model reaches its maximum token generation limit.
+
+ This exception is raised when the model stops generating tokens because it has reached the maximum number of
+ tokens allowed for output generation. This can occur when the model's max_tokens parameter is set too low for
+ the complexity of the response, or when the model naturally reaches its configured output limit during generation.
+ """
+
+ def __init__(self, message: str, incomplete_message: Message):
+ """Initialize the exception with an error message and the incomplete message object.
+
+ Args:
+ message: The error message describing the token limit issue
+ incomplete_message: The valid Message object with incomplete content due to token limits
+ """
+ self.incomplete_message = incomplete_message
+ super().__init__(message)
+
+
class ContextWindowOverflowException(Exception):
"""Exception raised when the context window is exceeded.
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py
index 1ac2f8258..3886df8b9 100644
--- a/tests/strands/event_loop/test_event_loop.py
+++ b/tests/strands/event_loop/test_event_loop.py
@@ -19,7 +19,12 @@
)
from strands.telemetry.metrics import EventLoopMetrics
from strands.tools.registry import ToolRegistry
-from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
+from strands.types.exceptions import (
+ ContextWindowOverflowException,
+ EventLoopException,
+ MaxTokensReachedException,
+ ModelThrottledException,
+)
from tests.fixtures.mock_hook_provider import MockHookProvider
@@ -556,6 +561,51 @@ async def test_event_loop_tracing_with_model_error(
mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect)
+@pytest.mark.asyncio
+async def test_event_loop_cycle_max_tokens_exception(
+ agent,
+ model,
+ agenerator,
+ alist,
+):
+ """Test that max_tokens stop reason raises MaxTokensReachedException."""
+
+ # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495
+ model.stream.return_value = agenerator(
+ [
+ {
+ "contentBlockStart": {
+ "start": {
+ "toolUse": {},
+ },
+ },
+ },
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "max_tokens"}},
+ ]
+ )
+
+ # Call event_loop_cycle, expecting it to raise MaxTokensReachedException
+ with pytest.raises(MaxTokensReachedException) as exc_info:
+ stream = strands.event_loop.event_loop.event_loop_cycle(
+ agent=agent,
+ invocation_state={},
+ )
+ await alist(stream)
+
+ # Verify the exception message contains the expected content
+ expected_message = (
+ "Agent has reached an unrecoverable state due to max_tokens limit. "
+ "For more information see: "
+ "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
+ )
+ assert str(exc_info.value) == expected_message
+
+ # Verify that the message has not been appended to the messages array
+ assert len(agent.messages) == 1
+ assert exc_info.value.incomplete_message not in agent.messages
+
+
@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_tracing_with_tool_execution(
diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py
new file mode 100644
index 000000000..d9c2817b3
--- /dev/null
+++ b/tests_integ/test_max_tokens_reached.py
@@ -0,0 +1,20 @@
+import pytest
+
+from strands import Agent, tool
+from strands.models.bedrock import BedrockModel
+from strands.types.exceptions import MaxTokensReachedException
+
+
+@tool
+def story_tool(story: str) -> str:
+ return story
+
+
+def test_max_tokens_reached():
+ model = BedrockModel(max_tokens=100)
+ agent = Agent(model=model, tools=[story_tool])
+
+ with pytest.raises(MaxTokensReachedException):
+ agent("Tell me a story!")
+
+ assert len(agent.messages) == 1
From 34d499aeea8ddb933c73711b1371704d4de8c9ba Mon Sep 17 00:00:00 2001
From: poshinchen
Date: Tue, 5 Aug 2025 12:39:21 -0400
Subject: [PATCH 020/104] fix(telemetry): added mcp tracing context propagation
(#569)
---
src/strands/tools/mcp/mcp_client.py | 2 +
src/strands/tools/mcp/mcp_instrumentation.py | 322 ++++++++++++
.../tools/mcp/test_mcp_instrumentation.py | 491 ++++++++++++++++++
3 files changed, 815 insertions(+)
create mode 100644 src/strands/tools/mcp/mcp_instrumentation.py
create mode 100644 tests/strands/tools/mcp/test_mcp_instrumentation.py
diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py
index 8c21baa4a..c1aa96df3 100644
--- a/src/strands/tools/mcp/mcp_client.py
+++ b/src/strands/tools/mcp/mcp_client.py
@@ -29,6 +29,7 @@
from ...types.media import ImageFormat
from ...types.tools import ToolResultContent, ToolResultStatus
from .mcp_agent_tool import MCPAgentTool
+from .mcp_instrumentation import mcp_instrumentation
from .mcp_types import MCPToolResult, MCPTransport
logger = logging.getLogger(__name__)
@@ -68,6 +69,7 @@ def __init__(self, transport_callable: Callable[[], MCPTransport]):
Args:
transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple
"""
+ mcp_instrumentation()
self._session_id = uuid.uuid4()
self._log_debug_with_thread("initializing MCPClient connection")
self._init_future: futures.Future[None] = futures.Future() # Main thread blocks until future completes
diff --git a/src/strands/tools/mcp/mcp_instrumentation.py b/src/strands/tools/mcp/mcp_instrumentation.py
new file mode 100644
index 000000000..338721db5
--- /dev/null
+++ b/src/strands/tools/mcp/mcp_instrumentation.py
@@ -0,0 +1,322 @@
+"""OpenTelemetry instrumentation for Model Context Protocol (MCP) tracing.
+
+Enables distributed tracing across MCP client-server boundaries by injecting
+OpenTelemetry context into MCP request metadata (_meta field) and extracting
+it on the server side, creating unified traces that span from agent calls
+through MCP tool executions.
+
+Based on: https://github.com/traceloop/openllmetry/tree/main/packages/opentelemetry-instrumentation-mcp
+Related issue: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/246
+"""
+
+from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
+from dataclasses import dataclass
+from typing import Any, AsyncGenerator, Callable, Tuple
+
+from mcp.shared.message import SessionMessage
+from mcp.types import JSONRPCMessage, JSONRPCRequest
+from opentelemetry import context, propagate
+from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper
+
+
+@dataclass(slots=True, frozen=True)
+class ItemWithContext:
+ """Wrapper for items that need to carry OpenTelemetry context.
+
+ Used to preserve tracing context across async boundaries in MCP sessions,
+ ensuring that distributed traces remain connected even when messages are
+ processed asynchronously.
+
+ Attributes:
+ item: The original item being wrapped
+ ctx: The OpenTelemetry context associated with the item
+ """
+
+ item: Any
+ ctx: context.Context
+
+
+def mcp_instrumentation() -> None:
+ """Apply OpenTelemetry instrumentation patches to MCP components.
+
+ This function instruments three key areas of MCP communication:
+ 1. Client-side: Injects tracing context into tool call requests
+ 2. Transport-level: Extracts context from incoming messages
+ 3. Session-level: Manages bidirectional context flow
+
+ The patches enable distributed tracing by:
+ - Adding OpenTelemetry context to the _meta field of MCP requests
+ - Extracting and activating context on the server side
+ - Preserving context across async message processing boundaries
+ """
+
+ def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any) -> Any:
+ """Patch MCP client to inject OpenTelemetry context into tool calls.
+
+ Intercepts outgoing MCP requests and injects the current OpenTelemetry
+ context into the request's _meta field for tools/call methods. This
+ enables server-side context extraction and trace continuation.
+
+ Args:
+ wrapped: The original function being wrapped
+ instance: The instance the method is being called on
+ args: Positional arguments to the wrapped function
+ kwargs: Keyword arguments to the wrapped function
+
+ Returns:
+ Result of the wrapped function call
+ """
+ if len(args) < 1:
+ return wrapped(*args, **kwargs)
+
+ request = args[0]
+ method = getattr(request.root, "method", None)
+
+ if method != "tools/call":
+ return wrapped(*args, **kwargs)
+
+ try:
+ if hasattr(request.root, "params") and request.root.params:
+ # Handle Pydantic models
+ if hasattr(request.root.params, "model_dump") and hasattr(request.root.params, "model_validate"):
+ params_dict = request.root.params.model_dump()
+ # Add _meta with tracing context
+ meta = params_dict.setdefault("_meta", {})
+ propagate.get_global_textmap().inject(meta)
+
+ # Recreate the Pydantic model with the updated data
+ # This preserves the original model type and avoids serialization warnings
+ params_class = type(request.root.params)
+ try:
+ request.root.params = params_class.model_validate(params_dict)
+ except Exception:
+ # Fallback to dict if model recreation fails
+ request.root.params = params_dict
+
+ elif isinstance(request.root.params, dict):
+ # Handle dict params directly
+ meta = request.root.params.setdefault("_meta", {})
+ propagate.get_global_textmap().inject(meta)
+
+ return wrapped(*args, **kwargs)
+
+ except Exception:
+ return wrapped(*args, **kwargs)
+
+ def transport_wrapper() -> Callable[
+ [Callable[..., Any], Any, Any, Any], _AsyncGeneratorContextManager[tuple[Any, Any]]
+ ]:
+ """Create a wrapper for MCP transport connections.
+
+ Returns a context manager that wraps transport read/write streams
+ with context extraction capabilities. The wrapped reader will
+ automatically extract OpenTelemetry context from incoming messages.
+
+ Returns:
+ An async context manager that yields wrapped transport streams
+ """
+
+ @asynccontextmanager
+ async def traced_method(
+ wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any
+ ) -> AsyncGenerator[Tuple[Any, Any], None]:
+ async with wrapped(*args, **kwargs) as result:
+ try:
+ read_stream, write_stream = result
+ except ValueError:
+ read_stream, write_stream, _ = result
+ yield TransportContextExtractingReader(read_stream), write_stream
+
+ return traced_method
+
+ def session_init_wrapper() -> Callable[[Any, Any, Tuple[Any, ...], dict[str, Any]], None]:
+ """Create a wrapper for MCP session initialization.
+
+ Wraps session message streams to enable bidirectional context flow.
+ The reader extracts and activates context, while the writer preserves
+ context for async processing.
+
+ Returns:
+ A function that wraps session initialization
+ """
+
+ def traced_method(
+ wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: dict[str, Any]
+ ) -> None:
+ wrapped(*args, **kwargs)
+ reader = getattr(instance, "_incoming_message_stream_reader", None)
+ writer = getattr(instance, "_incoming_message_stream_writer", None)
+ if reader and writer:
+ instance._incoming_message_stream_reader = SessionContextAttachingReader(reader)
+ instance._incoming_message_stream_writer = SessionContextSavingWriter(writer)
+
+ return traced_method
+
+ # Apply patches
+ wrap_function_wrapper("mcp.shared.session", "BaseSession.send_request", patch_mcp_client)
+
+ register_post_import_hook(
+ lambda _: wrap_function_wrapper(
+ "mcp.server.streamable_http", "StreamableHTTPServerTransport.connect", transport_wrapper()
+ ),
+ "mcp.server.streamable_http",
+ )
+
+ register_post_import_hook(
+ lambda _: wrap_function_wrapper("mcp.server.session", "ServerSession.__init__", session_init_wrapper()),
+ "mcp.server.session",
+ )
+
+
+class TransportContextExtractingReader(ObjectProxy):
+ """A proxy reader that extracts OpenTelemetry context from MCP messages.
+
+ Wraps an async message stream reader to automatically extract and activate
+ OpenTelemetry context from the _meta field of incoming MCP requests. This
+ enables server-side trace continuation from client-injected context.
+
+ The reader handles both SessionMessage and JSONRPCMessage formats, and
+ supports both dict and Pydantic model parameter structures.
+ """
+
+ def __init__(self, wrapped: Any) -> None:
+ """Initialize the context-extracting reader.
+
+ Args:
+ wrapped: The original async stream reader to wrap
+ """
+ super().__init__(wrapped)
+
+ async def __aenter__(self) -> Any:
+ """Enter the async context manager by delegating to the wrapped object."""
+ return await self.__wrapped__.__aenter__()
+
+ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
+ """Exit the async context manager by delegating to the wrapped object."""
+ return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)
+
+ async def __aiter__(self) -> AsyncGenerator[Any, None]:
+ """Iterate over messages, extracting and activating context as needed.
+
+ For each incoming message, checks if it contains tracing context in
+ the _meta field. If found, extracts and activates the context for
+ the duration of message processing, then properly detaches it.
+
+ Yields:
+ Messages from the wrapped stream, processed under the appropriate
+ OpenTelemetry context
+ """
+ async for item in self.__wrapped__:
+ if isinstance(item, SessionMessage):
+ request = item.message.root
+ elif type(item) is JSONRPCMessage:
+ request = item.root
+ else:
+ yield item
+ continue
+
+ if isinstance(request, JSONRPCRequest) and request.params:
+ # Handle both dict and Pydantic model params
+ if hasattr(request.params, "get"):
+ # Dict-like access
+ meta = request.params.get("_meta")
+ elif hasattr(request.params, "_meta"):
+ # Direct attribute access for Pydantic models
+ meta = getattr(request.params, "_meta", None)
+ else:
+ meta = None
+
+ if meta:
+ extracted_context = propagate.extract(meta)
+ restore = context.attach(extracted_context)
+ try:
+ yield item
+ continue
+ finally:
+ context.detach(restore)
+ yield item
+
+
+class SessionContextSavingWriter(ObjectProxy):
+ """A proxy writer that preserves OpenTelemetry context with outgoing items.
+
+ Wraps an async message stream writer to capture the current OpenTelemetry
+ context and associate it with outgoing items. This enables context
+ preservation across async boundaries in MCP session processing.
+ """
+
+ def __init__(self, wrapped: Any) -> None:
+ """Initialize the context-saving writer.
+
+ Args:
+ wrapped: The original async stream writer to wrap
+ """
+ super().__init__(wrapped)
+
+ async def __aenter__(self) -> Any:
+ """Enter the async context manager by delegating to the wrapped object."""
+ return await self.__wrapped__.__aenter__()
+
+ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
+ """Exit the async context manager by delegating to the wrapped object."""
+ return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)
+
+ async def send(self, item: Any) -> Any:
+ """Send an item while preserving the current OpenTelemetry context.
+
+ Captures the current context and wraps the item with it, enabling
+ the receiving side to restore the appropriate tracing context.
+
+ Args:
+ item: The item to send through the stream
+
+ Returns:
+ Result of sending the wrapped item
+ """
+ ctx = context.get_current()
+ return await self.__wrapped__.send(ItemWithContext(item, ctx))
+
+
+class SessionContextAttachingReader(ObjectProxy):
+ """A proxy reader that restores OpenTelemetry context from wrapped items.
+
+ Wraps an async message stream reader to detect ItemWithContext instances
+ and restore their associated OpenTelemetry context during processing.
+ This completes the context preservation cycle started by SessionContextSavingWriter.
+ """
+
+ def __init__(self, wrapped: Any) -> None:
+ """Initialize the context-attaching reader.
+
+ Args:
+ wrapped: The original async stream reader to wrap
+ """
+ super().__init__(wrapped)
+
+ async def __aenter__(self) -> Any:
+ """Enter the async context manager by delegating to the wrapped object."""
+ return await self.__wrapped__.__aenter__()
+
+ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
+ """Exit the async context manager by delegating to the wrapped object."""
+ return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)
+
+ async def __aiter__(self) -> AsyncGenerator[Any, None]:
+ """Iterate over items, restoring context for ItemWithContext instances.
+
+ For items wrapped with context, temporarily activates the associated
+ OpenTelemetry context during processing, then properly detaches it.
+ Regular items are yielded without context modification.
+
+ Yields:
+ Unwrapped items processed under their associated OpenTelemetry context
+ """
+ async for item in self.__wrapped__:
+ if isinstance(item, ItemWithContext):
+ restore = context.attach(item.ctx)
+ try:
+ yield item.item
+ finally:
+ context.detach(restore)
+ else:
+ yield item
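The mechanism shared by all three patches is a plain OpenTelemetry inject/extract round trip over the `_meta` dict; a minimal standalone sketch (no MCP types involved):

```python
from opentelemetry import context, propagate

# Client side: patch_mcp_client injects the current trace context
# into the outgoing request's _meta field.
meta: dict[str, str] = {}
propagate.get_global_textmap().inject(meta)

# Server side: the wrapped reader extracts and activates that context
# around message processing, then detaches it.
restored = propagate.extract(meta)
token = context.attach(restored)
try:
    pass  # handle the request under the restored trace context
finally:
    context.detach(token)
```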
diff --git a/tests/strands/tools/mcp/test_mcp_instrumentation.py b/tests/strands/tools/mcp/test_mcp_instrumentation.py
new file mode 100644
index 000000000..61a485777
--- /dev/null
+++ b/tests/strands/tools/mcp/test_mcp_instrumentation.py
@@ -0,0 +1,491 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from mcp.shared.message import SessionMessage
+from mcp.types import JSONRPCMessage, JSONRPCRequest
+from opentelemetry import context, propagate
+
+from strands.tools.mcp.mcp_instrumentation import (
+ ItemWithContext,
+ SessionContextAttachingReader,
+ SessionContextSavingWriter,
+ TransportContextExtractingReader,
+ mcp_instrumentation,
+)
+
+
+class TestItemWithContext:
+ def test_item_with_context_creation(self):
+ """Test that ItemWithContext correctly stores item and context."""
+ test_item = {"test": "data"}
+ test_context = context.get_current()
+
+ wrapped = ItemWithContext(test_item, test_context)
+
+ assert wrapped.item == test_item
+ assert wrapped.ctx == test_context
+
+
+class TestTransportContextExtractingReader:
+ @pytest.fixture
+ def mock_wrapped_reader(self):
+ """Create a mock wrapped reader."""
+ mock_reader = AsyncMock()
+ mock_reader.__aenter__ = AsyncMock(return_value=mock_reader)
+ mock_reader.__aexit__ = AsyncMock()
+ return mock_reader
+
+ def test_init(self, mock_wrapped_reader):
+ """Test reader initialization."""
+ reader = TransportContextExtractingReader(mock_wrapped_reader)
+ assert reader.__wrapped__ == mock_wrapped_reader
+
+ @pytest.mark.asyncio
+ async def test_context_manager_methods(self, mock_wrapped_reader):
+ """Test async context manager methods delegate correctly."""
+ reader = TransportContextExtractingReader(mock_wrapped_reader)
+
+ await reader.__aenter__()
+ mock_wrapped_reader.__aenter__.assert_called_once()
+
+ await reader.__aexit__(None, None, None)
+ mock_wrapped_reader.__aexit__.assert_called_once_with(None, None, None)
+
+ @pytest.mark.asyncio
+ async def test_aiter_with_session_message_and_dict_meta(self, mock_wrapped_reader):
+ """Test context extraction from SessionMessage with dict params containing _meta."""
+ # Create mock message with dict params containing _meta
+ mock_request = MagicMock(spec=JSONRPCRequest)
+ mock_request.params = {"_meta": {"traceparent": "test-trace-id"}, "other": "data"}
+
+ mock_message = MagicMock()
+ mock_message.root = mock_request
+
+ mock_session_message = MagicMock(spec=SessionMessage)
+ mock_session_message.message = mock_message
+
+ async def async_iter():
+ for item in [mock_session_message]:
+ yield item
+
+ mock_wrapped_reader.__aiter__ = lambda self: async_iter()
+
+ reader = TransportContextExtractingReader(mock_wrapped_reader)
+
+ with (
+ patch.object(propagate, "extract") as mock_extract,
+ patch.object(context, "attach") as mock_attach,
+ patch.object(context, "detach") as mock_detach,
+ ):
+ mock_context = MagicMock()
+ mock_extract.return_value = mock_context
+ mock_token = MagicMock()
+ mock_attach.return_value = mock_token
+
+ items = []
+ async for item in reader:
+ items.append(item)
+
+ assert len(items) == 1
+ assert items[0] == mock_session_message
+
+ mock_extract.assert_called_once_with({"traceparent": "test-trace-id"})
+ mock_attach.assert_called_once_with(mock_context)
+ mock_detach.assert_called_once_with(mock_token)
+
+ @pytest.mark.asyncio
+ async def test_aiter_with_session_message_and_pydantic_meta(self, mock_wrapped_reader):
+ """Test context extraction from SessionMessage with Pydantic params having _meta attribute."""
+ # Create mock message with Pydantic-style params
+ mock_request = MagicMock(spec=JSONRPCRequest)
+
+ # Create a mock params object that doesn't have 'get' method but has '_meta' attribute
+ mock_params = MagicMock()
+ # Remove the get method to simulate Pydantic model behavior
+ del mock_params.get
+ mock_params._meta = {"traceparent": "test-trace-id"}
+ mock_request.params = mock_params
+
+ mock_message = MagicMock()
+ mock_message.root = mock_request
+
+ mock_session_message = MagicMock(spec=SessionMessage)
+ mock_session_message.message = mock_message
+
+ async def async_iter():
+ for item in [mock_session_message]:
+ yield item
+
+ mock_wrapped_reader.__aiter__ = lambda self: async_iter()
+
+ reader = TransportContextExtractingReader(mock_wrapped_reader)
+
+ with (
+ patch.object(propagate, "extract") as mock_extract,
+ patch.object(context, "attach") as mock_attach,
+ patch.object(context, "detach") as mock_detach,
+ ):
+ mock_context = MagicMock()
+ mock_extract.return_value = mock_context
+ mock_token = MagicMock()
+ mock_attach.return_value = mock_token
+
+ items = []
+ async for item in reader:
+ items.append(item)
+
+ assert len(items) == 1
+ assert items[0] == mock_session_message
+
+ mock_extract.assert_called_once_with({"traceparent": "test-trace-id"})
+ mock_attach.assert_called_once_with(mock_context)
+ mock_detach.assert_called_once_with(mock_token)
+
+ @pytest.mark.asyncio
+ async def test_aiter_with_jsonrpc_message_no_meta(self, mock_wrapped_reader):
+ """Test handling JSONRPCMessage without _meta."""
+ mock_request = MagicMock(spec=JSONRPCRequest)
+ mock_request.params = {"other": "data"}
+
+ mock_message = MagicMock(spec=JSONRPCMessage)
+ mock_message.root = mock_request
+
+ async def async_iter():
+ for item in [mock_message]:
+ yield item
+
+ mock_wrapped_reader.__aiter__ = lambda self: async_iter()
+
+ reader = TransportContextExtractingReader(mock_wrapped_reader)
+
+ items = []
+ async for item in reader:
+ items.append(item)
+
+ assert len(items) == 1
+ assert items[0] == mock_message
+
+ @pytest.mark.asyncio
+ async def test_aiter_with_non_message_item(self, mock_wrapped_reader):
+ """Test handling non-message items."""
+ other_item = {"not": "a message"}
+
+ async def async_iter():
+ for item in [other_item]:
+ yield item
+
+ mock_wrapped_reader.__aiter__ = lambda self: async_iter()
+
+ reader = TransportContextExtractingReader(mock_wrapped_reader)
+
+ items = []
+ async for item in reader:
+ items.append(item)
+
+ assert len(items) == 1
+ assert items[0] == other_item
+
+
+class TestSessionContextSavingWriter:
+ @pytest.fixture
+ def mock_wrapped_writer(self):
+ """Create a mock wrapped writer."""
+ mock_writer = AsyncMock()
+ mock_writer.__aenter__ = AsyncMock(return_value=mock_writer)
+ mock_writer.__aexit__ = AsyncMock()
+ mock_writer.send = AsyncMock()
+ return mock_writer
+
+ def test_init(self, mock_wrapped_writer):
+ """Test writer initialization."""
+ writer = SessionContextSavingWriter(mock_wrapped_writer)
+ assert writer.__wrapped__ == mock_wrapped_writer
+
+ @pytest.mark.asyncio
+ async def test_context_manager_methods(self, mock_wrapped_writer):
+ """Test async context manager methods delegate correctly."""
+ writer = SessionContextSavingWriter(mock_wrapped_writer)
+
+ await writer.__aenter__()
+ mock_wrapped_writer.__aenter__.assert_called_once()
+
+ await writer.__aexit__(None, None, None)
+ mock_wrapped_writer.__aexit__.assert_called_once_with(None, None, None)
+
+ @pytest.mark.asyncio
+ async def test_send_wraps_item_with_context(self, mock_wrapped_writer):
+ """Test that send wraps items with current context."""
+ writer = SessionContextSavingWriter(mock_wrapped_writer)
+ test_item = {"test": "data"}
+
+ with patch.object(context, "get_current") as mock_get_current:
+ mock_context = MagicMock()
+ mock_get_current.return_value = mock_context
+
+ await writer.send(test_item)
+
+ mock_get_current.assert_called_once()
+ mock_wrapped_writer.send.assert_called_once()
+
+ # Verify the item was wrapped with context
+ sent_item = mock_wrapped_writer.send.call_args[0][0]
+ assert isinstance(sent_item, ItemWithContext)
+ assert sent_item.item == test_item
+ assert sent_item.ctx == mock_context
+
+
+class TestSessionContextAttachingReader:
+ @pytest.fixture
+ def mock_wrapped_reader(self):
+ """Create a mock wrapped reader."""
+ mock_reader = AsyncMock()
+ mock_reader.__aenter__ = AsyncMock(return_value=mock_reader)
+ mock_reader.__aexit__ = AsyncMock()
+ return mock_reader
+
+ def test_init(self, mock_wrapped_reader):
+ """Test reader initialization."""
+ reader = SessionContextAttachingReader(mock_wrapped_reader)
+ assert reader.__wrapped__ == mock_wrapped_reader
+
+ @pytest.mark.asyncio
+ async def test_context_manager_methods(self, mock_wrapped_reader):
+ """Test async context manager methods delegate correctly."""
+ reader = SessionContextAttachingReader(mock_wrapped_reader)
+
+ await reader.__aenter__()
+ mock_wrapped_reader.__aenter__.assert_called_once()
+
+ await reader.__aexit__(None, None, None)
+ mock_wrapped_reader.__aexit__.assert_called_once_with(None, None, None)
+
+ @pytest.mark.asyncio
+ async def test_aiter_with_item_with_context(self, mock_wrapped_reader):
+ """Test context restoration from ItemWithContext."""
+ test_item = {"test": "data"}
+ test_context = MagicMock()
+ wrapped_item = ItemWithContext(test_item, test_context)
+
+ async def async_iter():
+ for item in [wrapped_item]:
+ yield item
+
+ mock_wrapped_reader.__aiter__ = lambda self: async_iter()
+
+ reader = SessionContextAttachingReader(mock_wrapped_reader)
+
+ with patch.object(context, "attach") as mock_attach, patch.object(context, "detach") as mock_detach:
+ mock_token = MagicMock()
+ mock_attach.return_value = mock_token
+
+ items = []
+ async for item in reader:
+ items.append(item)
+
+ assert len(items) == 1
+ assert items[0] == test_item
+
+ mock_attach.assert_called_once_with(test_context)
+ mock_detach.assert_called_once_with(mock_token)
+
+ @pytest.mark.asyncio
+ async def test_aiter_with_regular_item(self, mock_wrapped_reader):
+ """Test handling regular items without context."""
+ regular_item = {"regular": "item"}
+
+ async def async_iter():
+ for item in [regular_item]:
+ yield item
+
+ mock_wrapped_reader.__aiter__ = lambda self: async_iter()
+
+ reader = SessionContextAttachingReader(mock_wrapped_reader)
+
+ items = []
+ async for item in reader:
+ items.append(item)
+
+ assert len(items) == 1
+ assert items[0] == regular_item
+
+
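+
+# Taken together, the writer/reader pair round-trips OpenTelemetry context across
+# a session boundary. A minimal relay sketch of the tested behavior (the import
+# path is assumed from the module these tests patch):
+#
+#     from strands.tools.mcp.mcp_instrumentation import (  # assumed module path
+#         SessionContextAttachingReader,
+#         SessionContextSavingWriter,
+#     )
+#
+#     async def relay(in_stream, out_stream):
+#         reader = SessionContextAttachingReader(in_stream)
+#         writer = SessionContextSavingWriter(out_stream)
+#         async with reader, writer:
+#             async for item in reader:    # saved OTel context is attached here...
+#                 await writer.send(item)  # ...and re-captured into ItemWithContext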
+# Mock Pydantic-like class for testing
+class MockPydanticParams:
+ """Mock class that behaves like a Pydantic model."""
+
+ def __init__(self, **data):
+ self._data = data
+
+ def model_dump(self):
+ return self._data.copy()
+
+ @classmethod
+ def model_validate(cls, data):
+ return cls(**data)
+
+ def __getattr__(self, name):
+ return self._data.get(name)
+
+
+class TestMCPInstrumentation:
+ def test_mcp_instrumentation_calls_wrap_function_wrapper(self):
+ """Test that mcp_instrumentation calls the expected wrapper functions."""
+ with (
+ patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap,
+ patch("strands.tools.mcp.mcp_instrumentation.register_post_import_hook") as mock_register,
+ ):
+ mcp_instrumentation()
+
+ # Verify wrap_function_wrapper was called for client patching
+ mock_wrap.assert_called_once_with(
+ "mcp.shared.session",
+ "BaseSession.send_request",
+ mock_wrap.call_args_list[0][0][2], # The patch function
+ )
+
+ # Verify register_post_import_hook was called for transport and session wrappers
+ assert mock_register.call_count == 2
+
+ # Check that the registered hooks are for the expected modules
+ registered_modules = [call[0][1] for call in mock_register.call_args_list]
+ assert "mcp.server.streamable_http" in registered_modules
+ assert "mcp.server.session" in registered_modules
+
+ def test_patch_mcp_client_injects_context_pydantic_model(self):
+ """Test that the client patch injects OpenTelemetry context into Pydantic models."""
+ # Create a mock request with tools/call method and Pydantic params
+ mock_request = MagicMock()
+ mock_request.root.method = "tools/call"
+
+ # Use our mock Pydantic-like class
+ mock_params = MockPydanticParams(existing="param")
+ mock_request.root.params = mock_params
+
+ # Create the patch function
+ with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap:
+ mcp_instrumentation()
+ patch_function = mock_wrap.call_args_list[0][0][2]
+
+ # Mock the wrapped function
+ mock_wrapped = MagicMock()
+
+ with patch.object(propagate, "get_global_textmap") as mock_textmap:
+ mock_textmap_instance = MagicMock()
+ mock_textmap.return_value = mock_textmap_instance
+
+ # Call the patch function
+ patch_function(mock_wrapped, None, [mock_request], {})
+
+ # Verify context was injected
+ mock_textmap_instance.inject.assert_called_once()
+ mock_wrapped.assert_called_once_with(mock_request)
+
+ # Verify the params object is still a MockPydanticParams (or dict if fallback occurred)
+ assert hasattr(mock_request.root.params, "model_dump") or isinstance(mock_request.root.params, dict)
+
+ def test_patch_mcp_client_injects_context_dict_params(self):
+ """Test that the client patch injects OpenTelemetry context into dict params."""
+ # Create a mock request with tools/call method and dict params
+ mock_request = MagicMock()
+ mock_request.root.method = "tools/call"
+ mock_request.root.params = {"existing": "param"}
+
+ # Create the patch function
+ with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap:
+ mcp_instrumentation()
+ patch_function = mock_wrap.call_args_list[0][0][2]
+
+ # Mock the wrapped function
+ mock_wrapped = MagicMock()
+
+ with patch.object(propagate, "get_global_textmap") as mock_textmap:
+ mock_textmap_instance = MagicMock()
+ mock_textmap.return_value = mock_textmap_instance
+
+ # Call the patch function
+ patch_function(mock_wrapped, None, [mock_request], {})
+
+ # Verify context was injected
+ mock_textmap_instance.inject.assert_called_once()
+ mock_wrapped.assert_called_once_with(mock_request)
+
+ # Verify _meta was added to the params dict
+ assert "_meta" in mock_request.root.params
+
+ def test_patch_mcp_client_skips_non_tools_call(self):
+ """Test that the client patch skips non-tools/call methods."""
+ mock_request = MagicMock()
+ mock_request.root.method = "other/method"
+
+ with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap:
+ mcp_instrumentation()
+ patch_function = mock_wrap.call_args_list[0][0][2]
+
+ mock_wrapped = MagicMock()
+
+ with patch.object(propagate, "get_global_textmap") as mock_textmap:
+ mock_textmap_instance = MagicMock()
+ mock_textmap.return_value = mock_textmap_instance
+
+ patch_function(mock_wrapped, None, [mock_request], {})
+
+ # Verify context injection was skipped
+ mock_textmap_instance.inject.assert_not_called()
+ mock_wrapped.assert_called_once_with(mock_request)
+
+ def test_patch_mcp_client_handles_exception_gracefully(self):
+ """Test that the client patch handles exceptions gracefully."""
+ # Create a mock request that will cause an exception
+ mock_request = MagicMock()
+ mock_request.root.method = "tools/call"
+ mock_request.root.params = MagicMock()
+ mock_request.root.params.model_dump.side_effect = Exception("Test exception")
+
+ with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap:
+ mcp_instrumentation()
+ patch_function = mock_wrap.call_args_list[0][0][2]
+
+ mock_wrapped = MagicMock()
+
+ # Should not raise an exception, should call wrapped function normally
+ patch_function(mock_wrapped, None, [mock_request], {})
+ mock_wrapped.assert_called_once_with(mock_request)
+
+ def test_patch_mcp_client_pydantic_fallback_to_dict(self):
+ """Test that Pydantic model recreation falls back to dict on failure."""
+
+ # Create a Pydantic-like class that fails on model_validate
+ class FailingMockPydanticParams:
+ def __init__(self, **data):
+ self._data = data
+
+ def model_dump(self):
+ return self._data.copy()
+
+ def model_validate(self, data):
+ raise Exception("Reconstruction failed")
+
+ # Create a mock request with failing Pydantic params
+ mock_request = MagicMock()
+ mock_request.root.method = "tools/call"
+
+ failing_params = FailingMockPydanticParams(existing="param")
+ mock_request.root.params = failing_params
+
+ with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap:
+ mcp_instrumentation()
+ patch_function = mock_wrap.call_args_list[0][0][2]
+
+ mock_wrapped = MagicMock()
+
+ with patch.object(propagate, "get_global_textmap") as mock_textmap:
+ mock_textmap_instance = MagicMock()
+ mock_textmap.return_value = mock_textmap_instance
+
+ # Call the patch function
+ patch_function(mock_wrapped, None, [mock_request], {})
+
+ # Verify it fell back to dict
+ assert isinstance(mock_request.root.params, dict)
+ assert "_meta" in mock_request.root.params
+ mock_wrapped.assert_called_once_with(mock_request)
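For orientation, the client-side patch these tests pin down can be sketched as follows; this is an illustrative reconstruction from the assertions above, not the actual implementation in strands.tools.mcp.mcp_instrumentation:
```
from opentelemetry import propagate

def patched_send_request(wrapped, instance, args, kwargs):
    # Inject the current OpenTelemetry context into tools/call params (sketch).
    request = args[0]
    try:
        if request.root.method == "tools/call":
            params = request.root.params
            if isinstance(params, dict):
                propagate.get_global_textmap().inject(params.setdefault("_meta", {}))
            else:  # Pydantic-style params
                data = params.model_dump()
                propagate.get_global_textmap().inject(data.setdefault("_meta", {}))
                try:
                    request.root.params = type(params).model_validate(data)
                except Exception:
                    request.root.params = data  # fall back to a plain dict
    except Exception:
        pass  # tracing must never break the underlying request
    return wrapped(*args, **kwargs)
```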
From 09ca806adf2f7efa367c812514435eb0089dcd0a Mon Sep 17 00:00:00 2001
From: Vince Mi
Date: Tue, 5 Aug 2025 10:32:30 -0700
Subject: [PATCH 021/104] Change max_tokens type to int to match Anthropic API
(#588)
Using a string causes the Anthropic API call to fail:
```
anthropic.BadRequestError: Error code: 400 - {'type': 'error', 'error': {'type': 'invalid_request_error', 'message': 'max_tokens: Input should be a valid integer'}}
```
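With the corrected type, configuration looks like this (a sketch; the model id is a placeholder and the constructor is assumed to accept the config keys as keyword arguments):
```
from strands.models.anthropic import AnthropicModel

model = AnthropicModel(
    model_id="anthropic-model-id",  # placeholder
    max_tokens=1024,                # int, not "1024"
)
```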
---
src/strands/models/anthropic.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index 975fca3e9..29cb40d40 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -55,7 +55,7 @@ class AnthropicConfig(TypedDict, total=False):
For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages.
"""
- max_tokens: Required[str]
+ max_tokens: Required[int]
model_id: Required[str]
params: Optional[dict[str, Any]]
From bf24ebf4d479cedea3f74452d8309142232203f9 Mon Sep 17 00:00:00 2001
From: mehtarac
Date: Wed, 6 Aug 2025 09:40:42 -0400
Subject: [PATCH 022/104] feat: Add additional instructions for contributors to
find issues that are ready to be worked on (#595)
---
CONTRIBUTING.md | 21 ++++++++++++---------
1 file changed, 12 insertions(+), 9 deletions(-)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index fa724cddc..add4825fd 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -25,6 +25,17 @@ Please try to include as much information as you can. Details like these are inc
* Anything unusual about your environment or deployment
+## Finding contributions to work on
+Looking at the existing issues is a great way to find something to contribute to. We label issues that are well-defined and ready for community contributions with the "ready for contribution" label.
+
+Check our [Ready for Contribution](../../issues?q=is%3Aissue%20state%3Aopen%20label%3A%22ready%20for%20contribution%22) issues for items you can work on.
+
+Before starting work on any issue:
+1. Check if someone is already assigned or working on it
+2. Comment on the issue to express your interest and ask any clarifying questions
+3. Wait for maintainer confirmation before beginning significant work
+
+
## Development Environment
This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as the build backend and [hatch](https://hatch.pypa.io/latest/) for development workflow management.
@@ -70,7 +81,7 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as
### Pre-commit Hooks
-We use [pre-commit](https://pre-commit.com/) to automatically run quality checks before each commit. The hook will run `hatch run format`, `hatch run lint`, `hatch run test`, and `hatch run cz check` on when you make a commit, ensuring code consistency.
+We use [pre-commit](https://pre-commit.com/) to automatically run quality checks before each commit. The hook will run `hatch run format`, `hatch run lint`, `hatch run test`, and `hatch run cz check` when you make a commit, ensuring code consistency.
The pre-commit hook is installed with:
@@ -122,14 +133,6 @@ To send us a pull request, please:
8. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
-## Finding contributions to work on
-Looking at the existing issues is a great way to find something to contribute to.
-
-You can check:
-- Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing
-- Feature requests in [Feature Requests](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Aenhancement) for new functionality to implement
-
-
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
From 297ec5cdfcd4b1e6e178429ef657911654e865b0 Mon Sep 17 00:00:00 2001
From: Jeremiah
Date: Wed, 6 Aug 2025 09:54:49 -0400
Subject: [PATCH 023/104] feat(a2a): configurable request handler (#601)
Co-authored-by: jer
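A sketch of wiring the new parameters through (A2AServer is assumed as the wrapper class name and import path; neither is shown in this hunk):
```
from a2a.server.tasks import InMemoryTaskStore
from strands import Agent
from strands.multiagent.a2a import A2AServer  # assumed import path

server = A2AServer(
    Agent(),
    task_store=InMemoryTaskStore(),  # any TaskStore implementation
    queue_manager=None,              # optional QueueManager
    push_config_store=None,          # optional PushNotificationConfigStore
    push_sender=None,                # optional PushNotificationSender
)
```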
---
src/strands/multiagent/a2a/server.py | 22 ++++++++++++++++++++--
1 file changed, 20 insertions(+), 2 deletions(-)
diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py
index fa7b6b887..35ea5b2e3 100644
--- a/src/strands/multiagent/a2a/server.py
+++ b/src/strands/multiagent/a2a/server.py
@@ -10,8 +10,9 @@
import uvicorn
from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication
+from a2a.server.events import QueueManager
from a2a.server.request_handlers import DefaultRequestHandler
-from a2a.server.tasks import InMemoryTaskStore
+from a2a.server.tasks import InMemoryTaskStore, PushNotificationConfigStore, PushNotificationSender, TaskStore
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from fastapi import FastAPI
from starlette.applications import Starlette
@@ -36,6 +37,12 @@ def __init__(
serve_at_root: bool = False,
version: str = "0.0.1",
skills: list[AgentSkill] | None = None,
+ # RequestHandler
+ task_store: TaskStore | None = None,
+ queue_manager: QueueManager | None = None,
+ push_config_store: PushNotificationConfigStore | None = None,
+ push_sender: PushNotificationSender | None = None,
+
):
"""Initialize an A2A-compatible server from a Strands agent.
@@ -52,6 +59,14 @@ def __init__(
Defaults to False.
version: The version of the agent. Defaults to "0.0.1".
skills: The list of capabilities or functions the agent can perform.
+ task_store: Custom task store implementation for managing agent tasks. If None,
+ uses InMemoryTaskStore.
+ queue_manager: Custom queue manager for handling message queues. If None,
+ no queue management is used.
+ push_config_store: Custom store for push notification configurations. If None,
+ no push notification configuration is used.
+ push_sender: Custom push notification sender implementation. If None,
+ no push notifications are sent.
"""
self.host = host
self.port = port
@@ -77,7 +92,10 @@ def __init__(
self.capabilities = AgentCapabilities(streaming=True)
self.request_handler = DefaultRequestHandler(
agent_executor=StrandsA2AExecutor(self.strands_agent),
- task_store=InMemoryTaskStore(),
+ task_store=task_store or InMemoryTaskStore(),
+ queue_manager=queue_manager,
+ push_config_store=push_config_store,
+ push_sender=push_sender,
)
self._agent_skills = skills
logger.info("Strands' integration with A2A is experimental. Be aware of frequent breaking changes.")
From ec5304c39809b99b1d29ed03ce7ae40536575e95 Mon Sep 17 00:00:00 2001
From: Jeremiah
Date: Wed, 6 Aug 2025 12:48:13 -0400
Subject: [PATCH 024/104] chore(a2a): update host per AppSec recommendation
(#619)
Co-authored-by: jer
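Binding to all interfaces remains possible as an explicit opt-in (sketch, same assumed class name as above):
```
server = A2AServer(agent, host="0.0.0.0", port=9000)  # deliberate opt-in
```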
---
src/strands/multiagent/a2a/server.py | 5 ++---
tests/strands/multiagent/a2a/test_server.py | 10 +++++-----
2 files changed, 7 insertions(+), 8 deletions(-)
diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py
index 35ea5b2e3..bbfbc824d 100644
--- a/src/strands/multiagent/a2a/server.py
+++ b/src/strands/multiagent/a2a/server.py
@@ -31,7 +31,7 @@ def __init__(
agent: SAAgent,
*,
# AgentCard
- host: str = "0.0.0.0",
+ host: str = "127.0.0.1",
port: int = 9000,
http_url: str | None = None,
serve_at_root: bool = False,
@@ -42,13 +42,12 @@ def __init__(
queue_manager: QueueManager | None = None,
push_config_store: PushNotificationConfigStore | None = None,
push_sender: PushNotificationSender | None = None,
-
):
"""Initialize an A2A-compatible server from a Strands agent.
Args:
agent: The Strands Agent to wrap with A2A compatibility.
- host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0".
+ host: The hostname or IP address to bind the A2A server to. Defaults to "127.0.0.1".
port: The port to bind the A2A server to. Defaults to 9000.
http_url: The public HTTP URL where this agent will be accessible. If provided,
this overrides the generated URL from host/port and enables automatic
diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py
index a3b47581c..00dd164b5 100644
--- a/tests/strands/multiagent/a2a/test_server.py
+++ b/tests/strands/multiagent/a2a/test_server.py
@@ -22,9 +22,9 @@ def test_a2a_agent_initialization(mock_strands_agent):
assert a2a_agent.strands_agent == mock_strands_agent
assert a2a_agent.name == "Test Agent"
assert a2a_agent.description == "A test agent for unit testing"
- assert a2a_agent.host == "0.0.0.0"
+ assert a2a_agent.host == "127.0.0.1"
assert a2a_agent.port == 9000
- assert a2a_agent.http_url == "http://0.0.0.0:9000/"
+ assert a2a_agent.http_url == "http://127.0.0.1:9000/"
assert a2a_agent.version == "0.0.1"
assert isinstance(a2a_agent.capabilities, AgentCapabilities)
assert len(a2a_agent.agent_skills) == 1
@@ -85,7 +85,7 @@ def test_public_agent_card(mock_strands_agent):
assert isinstance(card, AgentCard)
assert card.name == "Test Agent"
assert card.description == "A test agent for unit testing"
- assert card.url == "http://0.0.0.0:9000/"
+ assert card.url == "http://127.0.0.1:9000/"
assert card.version == "0.0.1"
assert card.default_input_modes == ["text"]
assert card.default_output_modes == ["text"]
@@ -448,7 +448,7 @@ def test_serve_with_starlette(mock_run, mock_strands_agent):
mock_run.assert_called_once()
args, kwargs = mock_run.call_args
assert isinstance(args[0], Starlette)
- assert kwargs["host"] == "0.0.0.0"
+ assert kwargs["host"] == "127.0.0.1"
assert kwargs["port"] == 9000
@@ -462,7 +462,7 @@ def test_serve_with_fastapi(mock_run, mock_strands_agent):
mock_run.assert_called_once()
args, kwargs = mock_run.call_args
assert isinstance(args[0], FastAPI)
- assert kwargs["host"] == "0.0.0.0"
+ assert kwargs["host"] == "127.0.0.1"
assert kwargs["port"] == 9000
From 29b21278f5816ffa01dbb555bb6ff192ae105d59 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 8 Aug 2025 10:43:34 -0400
Subject: [PATCH 025/104] fix(event_loop): ensure tool_use content blocks are
valid after max_tokens to prevent unrecoverable state (#607)
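With this change the exception no longer carries the incomplete message; the recovered message is already appended to the conversation, so the agent stays usable (sketch, assuming an existing agent):
```
from strands.types.exceptions import MaxTokensReachedException

try:
    agent("Tell me a story!")
except MaxTokensReachedException as e:
    print(e)  # the cleaned-up message is already in agent.messages
```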
---
.../_recover_message_on_max_tokens_reached.py | 71 +++++
src/strands/event_loop/event_loop.py | 32 ++-
src/strands/types/exceptions.py | 6 +-
tests/strands/event_loop/test_event_loop.py | 55 ++--
...t_recover_message_on_max_tokens_reached.py | 269 ++++++++++++++++++
tests_integ/test_max_tokens_reached.py | 32 ++-
6 files changed, 420 insertions(+), 45 deletions(-)
create mode 100644 src/strands/event_loop/_recover_message_on_max_tokens_reached.py
create mode 100644 tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py
diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py
new file mode 100644
index 000000000..ab6fb4abe
--- /dev/null
+++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py
@@ -0,0 +1,71 @@
+"""Message recovery utilities for handling max token limit scenarios.
+
+This module provides functionality to recover and clean up incomplete messages that occur
+when model responses are truncated due to maximum token limits being reached. It specifically
+handles cases where tool use blocks are incomplete or malformed due to truncation.
+"""
+
+import logging
+
+from ..types.content import ContentBlock, Message
+from ..types.tools import ToolUse
+
+logger = logging.getLogger(__name__)
+
+
+def recover_message_on_max_tokens_reached(message: Message) -> Message:
+ """Recover and clean up messages when max token limits are reached.
+
+ When a model response is truncated due to maximum token limits, all tool use blocks
+ should be replaced with informative error messages since they may be incomplete or
+ unreliable. This function inspects the message content and:
+
+ 1. Identifies all tool use blocks (regardless of validity)
+ 2. Replaces all tool uses with informative error messages
+ 3. Preserves all non-tool content blocks (text, images, etc.)
+ 4. Returns a cleaned message suitable for conversation history
+
+ This recovery mechanism ensures that the conversation can continue gracefully even when
+ model responses are truncated, providing clear feedback about what happened and preventing
+ potentially incomplete or corrupted tool executions.
+
+ Args:
+ message: The potentially incomplete message from the model that was truncated
+ due to max token limits.
+
+ Returns:
+ A cleaned Message with all tool uses replaced by explanatory text content.
+ The returned message maintains the same role as the input message.
+
+ Example:
+ If a message contains any tool use (complete or incomplete):
+ ```
+ {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}
+ ```
+
+ It will be replaced with:
+ ```
+ {"text": "The selected tool calculator's tool use was incomplete due to maximum token limits being reached."}
+ ```
+ """
+ logger.info("handling max_tokens stop reason - replacing all tool uses with error messages")
+
+ valid_content: list[ContentBlock] = []
+ for content in message["content"] or []:
+ tool_use: ToolUse | None = content.get("toolUse")
+ if not tool_use:
+ valid_content.append(content)
+ continue
+
+ # Replace all tool uses with error messages when max_tokens is reached
+ display_name = tool_use.get("name") or ""
+ logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name)
+
+ valid_content.append(
+ {
+ "text": f"The selected tool {display_name}'s tool use was incomplete due "
+ f"to maximum token limits being reached."
+ }
+ )
+
+ return {"content": valid_content, "role": message["role"]}
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index ae21d4c6d..b36f73155 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -36,6 +36,7 @@
)
from ..types.streaming import Metrics, StopReason
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
+from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
from .streaming import stream_messages
if TYPE_CHECKING:
@@ -156,6 +157,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
)
)
+ if stop_reason == "max_tokens":
+ message = recover_message_on_max_tokens_reached(message)
+
if model_invoke_span:
tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
break # Success! Break out of retry loop
@@ -192,6 +196,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
raise e
try:
+ # Add message in trace and mark the end of the stream messages trace
+ stream_trace.add_message(message)
+ stream_trace.end()
+
+ # Add the response message to the conversation
+ agent.messages.append(message)
+ agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
+ yield {"callback": {"message": message}}
+
+ # Update metrics
+ agent.event_loop_metrics.update_usage(usage)
+ agent.event_loop_metrics.update_metrics(metrics)
+
if stop_reason == "max_tokens":
"""
Handle max_tokens limit reached by the model.
@@ -205,21 +222,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
"Agent has reached an unrecoverable state due to max_tokens limit. "
"For more information see: "
"https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
- ),
- incomplete_message=message,
+ )
)
- # Add message in trace and mark the end of the stream messages trace
- stream_trace.add_message(message)
- stream_trace.end()
-
- # Add the response message to the conversation
- agent.messages.append(message)
- agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
- yield {"callback": {"message": message}}
-
- # Update metrics
- agent.event_loop_metrics.update_usage(usage)
- agent.event_loop_metrics.update_metrics(metrics)
# If the model is requesting to use tools
if stop_reason == "tool_use":
diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py
index 71ea28b9f..90f2b8d7f 100644
--- a/src/strands/types/exceptions.py
+++ b/src/strands/types/exceptions.py
@@ -2,8 +2,6 @@
from typing import Any
-from strands.types.content import Message
-
class EventLoopException(Exception):
"""Exception raised by the event loop."""
@@ -28,14 +26,12 @@ class MaxTokensReachedException(Exception):
the complexity of the response, or when the model naturally reaches its configured output limit during generation.
"""
- def __init__(self, message: str, incomplete_message: Message):
+ def __init__(self, message: str):
"""Initialize the exception with an error message and the incomplete message object.
Args:
message: The error message describing the token limit issue
- incomplete_message: The valid Message object with incomplete content due to token limits
"""
- self.incomplete_message = incomplete_message
super().__init__(message)
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py
index 3886df8b9..191ab51ba 100644
--- a/tests/strands/event_loop/test_event_loop.py
+++ b/tests/strands/event_loop/test_event_loop.py
@@ -305,8 +305,10 @@ async def test_event_loop_cycle_text_response_error(
await alist(stream)
+@patch("strands.event_loop.event_loop.recover_message_on_max_tokens_reached")
@pytest.mark.asyncio
async def test_event_loop_cycle_tool_result(
+ mock_recover_message,
agent,
model,
system_prompt,
@@ -339,6 +341,9 @@ async def test_event_loop_cycle_tool_result(
assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state
+ # Verify that recover_message_on_max_tokens_reached was NOT called for tool_use stop reason
+ mock_recover_message.assert_not_called()
+
model.stream.assert_called_with(
[
{"role": "user", "content": [{"text": "Hello"}]},
@@ -568,25 +573,35 @@ async def test_event_loop_cycle_max_tokens_exception(
agenerator,
alist,
):
- """Test that max_tokens stop reason raises MaxTokensReachedException."""
+ """Test that max_tokens stop reason calls _recover_message_on_max_tokens_reached then MaxTokensReachedException."""
- # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495
- model.stream.return_value = agenerator(
- [
- {
- "contentBlockStart": {
- "start": {
- "toolUse": {},
+ model.stream.side_effect = [
+ agenerator(
+ [
+ {
+ "contentBlockStart": {
+ "start": {
+ "toolUse": {
+ "toolUseId": "t1",
+ "name": "asdf",
+ "input": {}, # empty
+ },
+ },
},
},
- },
- {"contentBlockStop": {}},
- {"messageStop": {"stopReason": "max_tokens"}},
- ]
- )
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "max_tokens"}},
+ ]
+ ),
+ ]
# Call event_loop_cycle, expecting it to raise MaxTokensReachedException
- with pytest.raises(MaxTokensReachedException) as exc_info:
+ expected_message = (
+ "Agent has reached an unrecoverable state due to max_tokens limit. "
+ "For more information see: "
+ "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
+ )
+ with pytest.raises(MaxTokensReachedException, match=expected_message):
stream = strands.event_loop.event_loop.event_loop_cycle(
agent=agent,
invocation_state={},
@@ -594,16 +609,8 @@ async def test_event_loop_cycle_max_tokens_exception(
await alist(stream)
# Verify the exception message contains the expected content
- expected_message = (
- "Agent has reached an unrecoverable state due to max_tokens limit. "
- "For more information see: "
- "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
- )
- assert str(exc_info.value) == expected_message
-
- # Verify that the message has not been appended to the messages array
- assert len(agent.messages) == 1
- assert exc_info.value.incomplete_message not in agent.messages
+ assert len(agent.messages) == 2
+ assert "tool use was incomplete due" in agent.messages[1]["content"][0]["text"]
@patch("strands.event_loop.event_loop.get_tracer")
diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py
new file mode 100644
index 000000000..402e90966
--- /dev/null
+++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py
@@ -0,0 +1,269 @@
+"""Tests for token limit recovery utility."""
+
+from strands.event_loop._recover_message_on_max_tokens_reached import (
+ recover_message_on_max_tokens_reached,
+)
+from strands.types.content import Message
+
+
+def test_recover_message_on_max_tokens_reached_with_incomplete_tool_use():
+ """Test recovery when incomplete tool use is present in the message."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "I'll help you with that."},
+ {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 2
+
+ # First content block should be preserved
+ assert result["content"][0] == {"text": "I'll help you with that."}
+
+ # Second content block should be replaced with error message
+ assert "text" in result["content"][1]
+ assert "calculator" in result["content"][1]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][1]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_missing_tool_name():
+ """Test recovery when tool use has no name."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Missing name
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 1
+
+ # Content should be replaced with an error message even though the tool name is empty
+ assert "text" in result["content"][0]
+ assert "The selected tool 's tool use was incomplete" in result["content"][0]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][0]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_missing_input():
+ """Test recovery when tool use has no input."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"toolUse": {"name": "calculator", "toolUseId": "123"}}, # Missing input
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 1
+
+ # Content should be replaced with error message
+ assert "text" in result["content"][0]
+ assert "calculator" in result["content"][0]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][0]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_missing_tool_use_id():
+ """Test recovery when tool use has no toolUseId."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}}}, # Missing toolUseId
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 1
+
+ # Content should be replaced with error message
+ assert "text" in result["content"][0]
+ assert "calculator" in result["content"][0]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][0]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_valid_tool_use():
+ """Test that even valid tool uses are replaced with error messages."""
+ complete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "I'll help you with that."},
+ {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}, # Valid
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(complete_message)
+
+ # Should replace even valid tool uses with error messages
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 2
+ assert result["content"][0] == {"text": "I'll help you with that."}
+
+ # Valid tool use should also be replaced with error message
+ assert "text" in result["content"][1]
+ assert "calculator" in result["content"][1]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][1]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_empty_content():
+ """Test handling of message with empty content."""
+ empty_message: Message = {"role": "assistant", "content": []}
+
+ result = recover_message_on_max_tokens_reached(empty_message)
+
+ # Should return message with empty content preserved
+ assert result["role"] == "assistant"
+ assert result["content"] == []
+
+
+def test_recover_message_on_max_tokens_reached_with_none_content():
+ """Test handling of message with None content."""
+ none_content_message: Message = {"role": "assistant", "content": None}
+
+ result = recover_message_on_max_tokens_reached(none_content_message)
+
+ # Should return message with empty content
+ assert result["role"] == "assistant"
+ assert result["content"] == []
+
+
+def test_recover_message_on_max_tokens_reached_with_mixed_content():
+ """Test recovery with mix of valid content and incomplete tool use."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "Let me calculate this for you."},
+ {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete
+ {"text": "And then I'll explain the result."},
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 3
+
+ # First and third content blocks should be preserved
+ assert result["content"][0] == {"text": "Let me calculate this for you."}
+ assert result["content"][2] == {"text": "And then I'll explain the result."}
+
+ # Second content block should be replaced with error message
+ assert "text" in result["content"][1]
+ assert "calculator" in result["content"][1]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][1]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_preserves_non_tool_content():
+ """Test that non-tool content is preserved as-is."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "Here's some text."},
+ {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}},
+ {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Incomplete
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 3
+
+ # First two content blocks should be preserved exactly
+ assert result["content"][0] == {"text": "Here's some text."}
+ assert result["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}
+
+ # Third content block should be replaced with error message
+ assert "text" in result["content"][2]
+ assert "" in result["content"][2]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][2]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_multiple_incomplete_tools():
+ """Test recovery with multiple incomplete tool uses."""
+ incomplete_message: Message = {
+ "role": "assistant",
+ "content": [
+ {"toolUse": {"name": "calculator", "input": {}}}, # Missing toolUseId
+ {"text": "Some text in between."},
+ {"toolUse": {"name": "", "input": {}, "toolUseId": "456"}}, # Missing name
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 3
+
+ # First tool use should be replaced
+ assert "text" in result["content"][0]
+ assert "calculator" in result["content"][0]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][0]["text"]
+
+ # Text content should be preserved
+ assert result["content"][1] == {"text": "Some text in between."}
+
+ # Second tool use (missing name) should also be replaced with an error message
+ assert "text" in result["content"][2]
+ assert "The selected tool 's tool use was incomplete" in result["content"][2]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][2]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_preserves_user_role():
+ """Test that the function preserves the original message role."""
+ incomplete_message: Message = {
+ "role": "user",
+ "content": [
+ {"toolUse": {"name": "calculator", "input": {}}}, # Missing toolUseId
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(incomplete_message)
+
+ # Should preserve the original role
+ assert result["role"] == "user"
+ assert len(result["content"]) == 1
+ assert "text" in result["content"][0]
+ assert "calculator" in result["content"][0]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_content_without_tool_use():
+ """Test handling of content blocks that don't have toolUse key."""
+ message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "Regular text content."},
+ {"someOtherKey": "someValue"}, # Content without toolUse
+ {"toolUse": {"name": "calculator"}}, # Incomplete tool use
+ ],
+ }
+
+ result = recover_message_on_max_tokens_reached(message)
+
+ # Check the corrected message content
+ assert result["role"] == "assistant"
+ assert len(result["content"]) == 3
+
+ # First two content blocks should be preserved
+ assert result["content"][0] == {"text": "Regular text content."}
+ assert result["content"][1] == {"someOtherKey": "someValue"}
+
+ # Third content block should be replaced with error message
+ assert "text" in result["content"][2]
+ assert "calculator" in result["content"][2]["text"]
+ assert "incomplete due to maximum token limits" in result["content"][2]["text"]
diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py
index d9c2817b3..bf5668349 100644
--- a/tests_integ/test_max_tokens_reached.py
+++ b/tests_integ/test_max_tokens_reached.py
@@ -1,20 +1,48 @@
+import logging
+
import pytest
+from strands.agent import AgentResult
from strands import Agent, tool
from strands.models.bedrock import BedrockModel
from strands.types.exceptions import MaxTokensReachedException
+logger = logging.getLogger(__name__)
+
@tool
def story_tool(story: str) -> str:
+ """
+ Tool that writes a story that is at least 50,000 lines long.
+ """
return story
-def test_context_window_overflow():
+def test_max_tokens_reached():
+ """Test that MaxTokensReachedException is raised but the agent can still rerun on the second pass"""
model = BedrockModel(max_tokens=100)
agent = Agent(model=model, tools=[story_tool])
+ # This should raise an exception
with pytest.raises(MaxTokensReachedException):
agent("Tell me a story!")
- assert len(agent.messages) == 1
+ # Validate that at least one message contains the incomplete tool use error message
+ expected_text = "tool use was incomplete due to maximum token limits being reached"
+ all_text_content = [
+ content_block["text"]
+ for message in agent.messages
+ for content_block in message.get("content", [])
+ if "text" in content_block
+ ]
+
+ assert any(expected_text in text for text in all_text_content), (
+ f"Expected to find message containing '{expected_text}' in agent messages"
+ )
+
+ # Remove tools from agent and re-run with a generic question
+ agent.tool_registry.registry = {}
+ agent.tool_registry.tool_config = {}
+
+ result: AgentResult = agent("What is 3+3")
+ assert result.stop_reason == "end_turn"
From adac26f15930fe2fc6754f5f9ddeab2ff9698463 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 8 Aug 2025 10:44:39 -0400
Subject: [PATCH 026/104] fix(structured_output): do not modify
conversation_history when prompt is passed (#628)
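After this change a prompt passed to structured_output is sent in a temporary messages array only (sketch; User and the agent instance are placeholders):
```
from pydantic import BaseModel

class User(BaseModel):
    name: str
    age: int

before = len(agent.messages)
user = agent.structured_output(User, "Jane Doe is 30 years old")
assert len(agent.messages) == before  # conversation history unchanged
```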
---
src/strands/agent/agent.py | 20 +++++-----
tests/strands/agent/test_agent.py | 52 +++++++++++++++++++++++++
tests/strands/agent/test_agent_hooks.py | 10 ++---
3 files changed, 67 insertions(+), 15 deletions(-)
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 111509e3a..2022142c6 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -403,8 +403,8 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
"""This method allows you to get structured output from the agent.
- If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
- If you don't pass in a prompt, it will use only the conversation history to respond.
+ If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
+ If you don't pass in a prompt, it will use only the existing conversation history to respond.
For smaller models, you may want to use the optional prompt to add additional instructions to explicitly
instruct the model to output the structured data.
@@ -412,7 +412,7 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l
Args:
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
that the agent will use when responding.
- prompt: The prompt to use for the agent.
+ prompt: The prompt to use for the agent (will not be added to conversation history).
Raises:
ValueError: If no conversation history or prompt is provided.
@@ -430,8 +430,8 @@ async def structured_output_async(
) -> T:
"""This method allows you to get structured output from the agent.
- If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
- If you don't pass in a prompt, it will use only the conversation history to respond.
+ If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
+ If you don't pass in a prompt, it will use only the existing conversation history to respond.
For smaller models, you may want to use the optional prompt to add additional instructions to explicitly
instruct the model to output the structured data.
@@ -439,7 +439,7 @@ async def structured_output_async(
Args:
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
that the agent will use when responding.
- prompt: The prompt to use for the agent.
+ prompt: The prompt to use for the agent (will not be added to conversation history).
Raises:
ValueError: If no conversation history or prompt is provided.
@@ -450,12 +450,14 @@ async def structured_output_async(
if not self.messages and not prompt:
raise ValueError("No conversation history or prompt provided")
- # add the prompt as the last message
+ # Create temporary messages array if prompt is provided
if prompt:
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
- self._append_message({"role": "user", "content": content})
+ temp_messages = self.messages + [{"role": "user", "content": content}]
+ else:
+ temp_messages = self.messages
- events = self.model.structured_output(output_model, self.messages, system_prompt=self.system_prompt)
+ events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
async for event in events:
if "callback" in event:
self.callback_handler(**cast(dict, event["callback"]))
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 4e310dace..c27243dfe 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -984,10 +984,17 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
prompt = "Jane Doe is 30 years old and her email is jane@doe.com"
+ # Store initial message count
+ initial_message_count = len(agent.messages)
+
tru_result = agent.structured_output(type(user), prompt)
exp_result = user
assert tru_result == exp_result
+ # Verify conversation history is not polluted
+ assert len(agent.messages) == initial_message_count
+
+ # Verify the model was called with temporary messages array
agent.model.structured_output.assert_called_once_with(
type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
)
@@ -1008,10 +1015,17 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a
},
]
+ # Store initial message count
+ initial_message_count = len(agent.messages)
+
tru_result = agent.structured_output(type(user), prompt)
exp_result = user
assert tru_result == exp_result
+ # Verify conversation history is not polluted
+ assert len(agent.messages) == initial_message_count
+
+ # Verify the model was called with temporary messages array
agent.model.structured_output.assert_called_once_with(
type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt
)
@@ -1023,10 +1037,41 @@ async def test_agent_structured_output_in_async_context(agent, user, agenerator)
prompt = "Jane Doe is 30 years old and her email is jane@doe.com"
+ # Store initial message count
+ initial_message_count = len(agent.messages)
+
tru_result = await agent.structured_output_async(type(user), prompt)
exp_result = user
assert tru_result == exp_result
+ # Verify conversation history is not polluted
+ assert len(agent.messages) == initial_message_count
+
+
+def test_agent_structured_output_without_prompt(agent, system_prompt, user, agenerator):
+ """Test that structured_output works with existing conversation history and no new prompt."""
+ agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
+
+ # Add some existing messages to the agent
+ existing_messages = [
+ {"role": "user", "content": [{"text": "Jane Doe is 30 years old"}]},
+ {"role": "assistant", "content": [{"text": "I understand."}]},
+ ]
+ agent.messages.extend(existing_messages)
+
+ initial_message_count = len(agent.messages)
+
+ tru_result = agent.structured_output(type(user)) # No prompt provided
+ exp_result = user
+ assert tru_result == exp_result
+
+ # Verify conversation history is unchanged
+ assert len(agent.messages) == initial_message_count
+ assert agent.messages == existing_messages
+
+ # Verify the model was called with existing messages only
+ agent.model.structured_output.assert_called_once_with(type(user), existing_messages, system_prompt=system_prompt)
+
@pytest.mark.asyncio
async def test_agent_structured_output_async(agent, system_prompt, user, agenerator):
@@ -1034,10 +1079,17 @@ async def test_agent_structured_output_async(agent, system_prompt, user, agenera
prompt = "Jane Doe is 30 years old and her email is jane@doe.com"
+ # Store initial message count
+ initial_message_count = len(agent.messages)
+
tru_result = agent.structured_output(type(user), prompt)
exp_result = user
assert tru_result == exp_result
+ # Verify conversation history is not polluted
+ assert len(agent.messages) == initial_message_count
+
+ # Verify the model was called with temporary messages array
agent.model.structured_output.assert_called_once_with(
type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
)
diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py
index cd89fbc7a..9ab008ca2 100644
--- a/tests/strands/agent/test_agent_hooks.py
+++ b/tests/strands/agent/test_agent_hooks.py
@@ -267,13 +267,12 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
length, events = hook_provider.get_events()
- assert length == 3
+ assert length == 2
assert next(events) == BeforeInvocationEvent(agent=agent)
- assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])
assert next(events) == AfterInvocationEvent(agent=agent)
- assert len(agent.messages) == 1
+ assert len(agent.messages) == 0 # no new messages added
@pytest.mark.asyncio
@@ -285,10 +284,9 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
length, events = hook_provider.get_events()
- assert length == 3
+ assert length == 2
assert next(events) == BeforeInvocationEvent(agent=agent)
- assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])
assert next(events) == AfterInvocationEvent(agent=agent)
- assert len(agent.messages) == 1
+ assert len(agent.messages) == 0 # no new messages added
From 99963b64c261431c6f10c31853a2dcc667a9ebbb Mon Sep 17 00:00:00 2001
From: Murat Kaan Meral
Date: Mon, 11 Aug 2025 17:11:27 +0200
Subject: [PATCH 027/104] feature(graph): Allow cyclic graphs (#497)
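A sketch of a bounded feedback loop enabled by this change (the agent variables are placeholders; add_edge accepting node ids is assumed from the builder API in the diff below):
```
builder = GraphBuilder()
builder.add_node(writer_agent, "writer")  # placeholder agents
builder.add_node(critic_agent, "critic")
builder.add_edge("writer", "critic")
builder.add_edge("critic", "writer")      # cycle: feedback loop
builder.set_entry_point("writer")

graph = (
    builder.reset_on_revisit()        # fresh node state on each revisit
    .set_max_node_executions(10)      # bound the loop
    .set_execution_timeout(60.0)      # seconds
    .set_node_timeout(15.0)
    .build()
)
```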
---
.gitignore | 3 +-
src/strands/multiagent/graph.py | 304 ++++++++++---
tests/strands/multiagent/test_graph.py | 579 ++++++++++++++++++++++++-
3 files changed, 805 insertions(+), 81 deletions(-)
diff --git a/.gitignore b/.gitignore
index cb34b9150..c27d1d902 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,5 @@ __pycache__*
*.bak
.vscode
dist
-repl_state
\ No newline at end of file
+repl_state
+.kiro
\ No newline at end of file
diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py
index cbba0fecf..9aee260b1 100644
--- a/src/strands/multiagent/graph.py
+++ b/src/strands/multiagent/graph.py
@@ -1,31 +1,33 @@
-"""Directed Acyclic Graph (DAG) Multi-Agent Pattern Implementation.
+"""Directed Graph Multi-Agent Pattern Implementation.
-This module provides a deterministic DAG-based agent orchestration system where
+This module provides a deterministic graph-based agent orchestration system where
agents or MultiAgentBase instances (like Swarm or Graph) are nodes in a graph,
executed according to edge dependencies, with output from one node passed as input
to connected nodes.
Key Features:
- Agents and MultiAgentBase instances (Swarm, Graph, etc.) as graph nodes
-- Deterministic execution order based on DAG structure
+- Deterministic execution based on dependency resolution
- Output propagation along edges
-- Topological sort for execution ordering
+- Support for cyclic graphs (feedback loops)
- Clear dependency management
- Supports nested graphs (Graph as a node in another Graph)
"""
import asyncio
+import copy
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, Optional, Tuple
from opentelemetry import trace as trace_api
from ..agent import Agent
+from ..agent.state import AgentState
from ..telemetry import get_tracer
-from ..types.content import ContentBlock
+from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
@@ -54,6 +56,7 @@ class GraphState:
completed_nodes: set["GraphNode"] = field(default_factory=set)
failed_nodes: set["GraphNode"] = field(default_factory=set)
execution_order: list["GraphNode"] = field(default_factory=list)
+ start_time: float = field(default_factory=time.time)
# Results
results: dict[str, NodeResult] = field(default_factory=dict)
@@ -69,6 +72,27 @@ class GraphState:
edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list)
entry_points: list["GraphNode"] = field(default_factory=list)
+ def should_continue(
+ self,
+ max_node_executions: Optional[int],
+ execution_timeout: Optional[float],
+ ) -> Tuple[bool, str]:
+ """Check if the graph should continue execution.
+
+ Returns: (should_continue, reason)
+ """
+ # Check node execution limit (only if set)
+ if max_node_executions is not None and len(self.execution_order) >= max_node_executions:
+ return False, f"Max node executions reached: {max_node_executions}"
+
+ # Check timeout (only if set)
+ if execution_timeout is not None:
+ elapsed = time.time() - self.start_time
+ if elapsed > execution_timeout:
+ return False, f"Execution timed out: {execution_timeout}s"
+
+ return True, "Continuing"
+
@dataclass
class GraphResult(MultiAgentResult):
@@ -117,6 +141,33 @@ class GraphNode:
execution_status: Status = Status.PENDING
result: NodeResult | None = None
execution_time: int = 0
+ _initial_messages: Messages = field(default_factory=list, init=False)
+ _initial_state: AgentState = field(default_factory=AgentState, init=False)
+
+ def __post_init__(self) -> None:
+ """Capture initial executor state after initialization."""
+ # Deep copy the initial messages and state to preserve them
+ if hasattr(self.executor, "messages"):
+ self._initial_messages = copy.deepcopy(self.executor.messages)
+
+ if hasattr(self.executor, "state") and hasattr(self.executor.state, "get"):
+ self._initial_state = AgentState(self.executor.state.get())
+
+ def reset_executor_state(self) -> None:
+ """Reset GraphNode executor state to initial state when graph was created.
+
+ This is useful when nodes are executed multiple times and need to start
+ fresh on each execution, providing stateless behavior.
+ """
+ if hasattr(self.executor, "messages"):
+ self.executor.messages = copy.deepcopy(self._initial_messages)
+
+ if hasattr(self.executor, "state"):
+ self.executor.state = AgentState(self._initial_state.get())
+
+ # Reset execution status
+ self.execution_status = Status.PENDING
+ self.result = None
def __hash__(self) -> int:
"""Return hash for GraphNode based on node_id."""
@@ -164,6 +215,12 @@ def __init__(self) -> None:
self.edges: set[GraphEdge] = set()
self.entry_points: set[GraphNode] = set()
+ # Configuration options
+ self._max_node_executions: Optional[int] = None
+ self._execution_timeout: Optional[float] = None
+ self._node_timeout: Optional[float] = None
+ self._reset_on_revisit: bool = False
+
def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
_validate_node_executor(executor, self.nodes)
@@ -213,8 +270,48 @@ def set_entry_point(self, node_id: str) -> "GraphBuilder":
self.entry_points.add(self.nodes[node_id])
return self
+ def reset_on_revisit(self, enabled: bool = True) -> "GraphBuilder":
+ """Control whether nodes reset their state when revisited.
+
+ When enabled, nodes will reset their messages and state to initial values
+ each time they are revisited (re-executed). This is useful for stateless
+ behavior where nodes should start fresh on each revisit.
+
+ Args:
+ enabled: Whether to reset node state when revisited (default: True)
+ """
+ self._reset_on_revisit = enabled
+ return self
+
+ def set_max_node_executions(self, max_executions: Optional[int]) -> "GraphBuilder":
+ """Set maximum number of node executions allowed.
+
+ Args:
+ max_executions: Maximum total node executions (None for no limit)
+ """
+ self._max_node_executions = max_executions
+ return self
+
+ def set_execution_timeout(self, timeout: Optional[float]) -> "GraphBuilder":
+ """Set total execution timeout.
+
+ Args:
+ timeout: Total execution timeout in seconds (None for no limit)
+ """
+ self._execution_timeout = timeout
+ return self
+
+ def set_node_timeout(self, timeout: Optional[float]) -> "GraphBuilder":
+ """Set individual node execution timeout.
+
+ Args:
+ timeout: Individual node timeout in seconds (None for no limit)
+ """
+ self._node_timeout = timeout
+ return self
+
def build(self) -> "Graph":
- """Build and validate the graph."""
+ """Build and validate the graph with configured settings."""
if not self.nodes:
raise ValueError("Graph must contain at least one node")
@@ -230,44 +327,53 @@ def build(self) -> "Graph":
# Validate entry points and check for cycles
self._validate_graph()
- return Graph(nodes=self.nodes.copy(), edges=self.edges.copy(), entry_points=self.entry_points.copy())
+ return Graph(
+ nodes=self.nodes.copy(),
+ edges=self.edges.copy(),
+ entry_points=self.entry_points.copy(),
+ max_node_executions=self._max_node_executions,
+ execution_timeout=self._execution_timeout,
+ node_timeout=self._node_timeout,
+ reset_on_revisit=self._reset_on_revisit,
+ )
def _validate_graph(self) -> None:
- """Validate graph structure and detect cycles."""
+ """Validate graph structure."""
# Validate entry points exist
entry_point_ids = {node.node_id for node in self.entry_points}
invalid_entries = entry_point_ids - set(self.nodes.keys())
if invalid_entries:
raise ValueError(f"Entry points not found in nodes: {invalid_entries}")
- # Check for cycles using DFS with color coding
- WHITE, GRAY, BLACK = 0, 1, 2
- colors = {node_id: WHITE for node_id in self.nodes}
-
- def has_cycle_from(node_id: str) -> bool:
- if colors[node_id] == GRAY:
- return True # Back edge found - cycle detected
- if colors[node_id] == BLACK:
- return False
-
- colors[node_id] = GRAY
- # Check all outgoing edges for cycles
- for edge in self.edges:
- if edge.from_node.node_id == node_id and has_cycle_from(edge.to_node.node_id):
- return True
- colors[node_id] = BLACK
- return False
-
- # Check for cycles from each unvisited node
- if any(colors[node_id] == WHITE and has_cycle_from(node_id) for node_id in self.nodes):
- raise ValueError("Graph contains cycles - must be a directed acyclic graph")
+ # Warn about potential infinite loops if no execution limits are set
+ if self._max_node_executions is None and self._execution_timeout is None:
+ logger.warning("Graph without execution limits may run indefinitely if cycles exist")
class Graph(MultiAgentBase):
- """Directed Acyclic Graph multi-agent orchestration."""
+ """Directed Graph multi-agent orchestration with configurable revisit behavior."""
- def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_points: set[GraphNode]) -> None:
- """Initialize Graph."""
+ def __init__(
+ self,
+ nodes: dict[str, GraphNode],
+ edges: set[GraphEdge],
+ entry_points: set[GraphNode],
+ max_node_executions: Optional[int] = None,
+ execution_timeout: Optional[float] = None,
+ node_timeout: Optional[float] = None,
+ reset_on_revisit: bool = False,
+ ) -> None:
+ """Initialize Graph with execution limits and reset behavior.
+
+ Args:
+ nodes: Dictionary of node_id to GraphNode
+ edges: Set of GraphEdge objects
+ entry_points: Set of GraphNode objects that are entry points
+ max_node_executions: Maximum total node executions (default: None - no limit)
+ execution_timeout: Total execution timeout in seconds (default: None - no limit)
+ node_timeout: Individual node timeout in seconds (default: None - no limit)
+ reset_on_revisit: Whether to reset node state when revisited (default: False)
+ """
super().__init__()
# Validate nodes for duplicate instances
@@ -276,6 +382,10 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
self.nodes = nodes
self.edges = edges
self.entry_points = entry_points
+ self.max_node_executions = max_node_executions
+ self.execution_timeout = execution_timeout
+ self.node_timeout = node_timeout
+ self.reset_on_revisit = reset_on_revisit
self.state = GraphState()
self.tracer = get_tracer()
@@ -294,20 +404,34 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G
logger.debug("task=<%s> | starting graph execution", task)
# Initialize state
+ start_time = time.time()
self.state = GraphState(
status=Status.EXECUTING,
task=task,
total_nodes=len(self.nodes),
edges=[(edge.from_node, edge.to_node) for edge in self.edges],
entry_points=list(self.entry_points),
+ start_time=start_time,
)
- start_time = time.time()
span = self.tracer.start_multiagent_span(task, "graph")
with trace_api.use_span(span, end_on_exit=True):
try:
+ logger.debug(
+ "max_node_executions=<%s>, execution_timeout=<%s>s, node_timeout=<%s>s | graph execution config",
+ self.max_node_executions or "None",
+ self.execution_timeout or "None",
+ self.node_timeout or "None",
+ )
+
await self._execute_graph()
- self.state.status = Status.COMPLETED
+
+ # Set final status based on execution results
+ if self.state.failed_nodes:
+ self.state.status = Status.FAILED
+ elif self.state.status == Status.EXECUTING: # Only set to COMPLETED if still executing and no failures
+ self.state.status = Status.COMPLETED
+
logger.debug("status=<%s> | graph execution completed", self.state.status)
except Exception:
@@ -335,6 +459,16 @@ async def _execute_graph(self) -> None:
ready_nodes = list(self.entry_points)
while ready_nodes:
+ # Check execution limits before continuing
+ should_continue, reason = self.state.should_continue(
+ max_node_executions=self.max_node_executions,
+ execution_timeout=self.execution_timeout,
+ )
+ if not should_continue:
+ self.state.status = Status.FAILED
+ logger.debug("reason=<%s> | stopping execution", reason)
+ return # Stop execution; the failure status is reflected in the final result
+
current_batch = ready_nodes.copy()
ready_nodes.clear()
@@ -386,7 +520,14 @@ def _is_node_ready_with_conditions(self, node: GraphNode) -> bool:
return False
async def _execute_node(self, node: GraphNode) -> None:
- """Execute a single node with error handling."""
+ """Execute a single node with error handling and timeout protection."""
+ # Reset the node's state if reset_on_revisit is enabled and it's being revisited
+ if self.reset_on_revisit and node in self.state.completed_nodes:
+ logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id)
+ node.reset_executor_state()
+ # Remove from completed nodes since we're re-executing it
+ self.state.completed_nodes.remove(node)
+
node.execution_status = Status.EXECUTING
logger.debug("node_id=<%s> | executing node", node.node_id)
@@ -395,42 +536,65 @@ async def _execute_node(self, node: GraphNode) -> None:
# Build node input from satisfied dependencies
node_input = self._build_node_input(node)
- # Execute based on node type and create unified NodeResult
- if isinstance(node.executor, MultiAgentBase):
- multi_agent_result = await node.executor.invoke_async(node_input)
-
- # Create NodeResult with MultiAgentResult directly
- node_result = NodeResult(
- result=multi_agent_result, # type is MultiAgentResult
- execution_time=multi_agent_result.execution_time,
- status=Status.COMPLETED,
- accumulated_usage=multi_agent_result.accumulated_usage,
- accumulated_metrics=multi_agent_result.accumulated_metrics,
- execution_count=multi_agent_result.execution_count,
- )
+ # Execute with timeout protection (only if node_timeout is set)
+ try:
+ # Execute based on node type and create unified NodeResult
+ if isinstance(node.executor, MultiAgentBase):
+ if self.node_timeout is not None:
+ multi_agent_result = await asyncio.wait_for(
+ node.executor.invoke_async(node_input),
+ timeout=self.node_timeout,
+ )
+ else:
+ multi_agent_result = await node.executor.invoke_async(node_input)
+
+ # Create NodeResult with MultiAgentResult directly
+ node_result = NodeResult(
+ result=multi_agent_result, # type is MultiAgentResult
+ execution_time=multi_agent_result.execution_time,
+ status=Status.COMPLETED,
+ accumulated_usage=multi_agent_result.accumulated_usage,
+ accumulated_metrics=multi_agent_result.accumulated_metrics,
+ execution_count=multi_agent_result.execution_count,
+ )
- elif isinstance(node.executor, Agent):
- agent_response = await node.executor.invoke_async(node_input)
-
- # Extract metrics from agent response
- usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
- metrics = Metrics(latencyMs=0)
- if hasattr(agent_response, "metrics") and agent_response.metrics:
- if hasattr(agent_response.metrics, "accumulated_usage"):
- usage = agent_response.metrics.accumulated_usage
- if hasattr(agent_response.metrics, "accumulated_metrics"):
- metrics = agent_response.metrics.accumulated_metrics
-
- node_result = NodeResult(
- result=agent_response, # type is AgentResult
- execution_time=round((time.time() - start_time) * 1000),
- status=Status.COMPLETED,
- accumulated_usage=usage,
- accumulated_metrics=metrics,
- execution_count=1,
+ elif isinstance(node.executor, Agent):
+ if self.node_timeout is not None:
+ agent_response = await asyncio.wait_for(
+ node.executor.invoke_async(node_input),
+ timeout=self.node_timeout,
+ )
+ else:
+ agent_response = await node.executor.invoke_async(node_input)
+
+ # Extract metrics from agent response
+ usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
+ metrics = Metrics(latencyMs=0)
+ if hasattr(agent_response, "metrics") and agent_response.metrics:
+ if hasattr(agent_response.metrics, "accumulated_usage"):
+ usage = agent_response.metrics.accumulated_usage
+ if hasattr(agent_response.metrics, "accumulated_metrics"):
+ metrics = agent_response.metrics.accumulated_metrics
+
+ node_result = NodeResult(
+ result=agent_response, # type is AgentResult
+ execution_time=round((time.time() - start_time) * 1000),
+ status=Status.COMPLETED,
+ accumulated_usage=usage,
+ accumulated_metrics=metrics,
+ execution_count=1,
+ )
+ else:
+ raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported")
+
+ except asyncio.TimeoutError:
+ timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s"
+ logger.exception(
+ "node=<%s>, timeout=<%s>s | node execution timed out after timeout",
+ node.node_id,
+ self.node_timeout,
)
- else:
- raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported")
+ raise Exception(timeout_msg) from None
# Mark as completed
node.execution_status = Status.COMPLETED
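The per-node guard above is the standard `asyncio.wait_for` pattern; stripped of the graph bookkeeping, the behavior is roughly the following (a self-contained sketch with a stand-in coroutine rather than a real executor):

```python
import asyncio

async def run_with_optional_timeout(coro, node_timeout):
    # Mirrors _execute_node: only wrap the executor call when a timeout is set
    if node_timeout is not None:
        return await asyncio.wait_for(coro, timeout=node_timeout)
    return await coro

async def slow_node():
    await asyncio.sleep(0.2)
    return "done"

try:
    asyncio.run(run_with_optional_timeout(slow_node(), node_timeout=0.1))
except asyncio.TimeoutError:
    print("node timed out, as _execute_node would report")
```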
diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py
index cb74f515c..c60361da8 100644
--- a/tests/strands/multiagent/test_graph.py
+++ b/tests/strands/multiagent/test_graph.py
@@ -1,8 +1,11 @@
+import asyncio
+import time
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from strands.agent import Agent, AgentResult
+from strands.agent.state import AgentState
from strands.hooks import AgentInitializedEvent
from strands.hooks.registry import HookProvider, HookRegistry
from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult
@@ -251,7 +254,8 @@ class UnsupportedExecutor:
builder.add_node(UnsupportedExecutor(), "unsupported_node")
graph = builder.build()
- with pytest.raises(ValueError, match="Node 'unsupported_node' of type.*is not supported"):
+ # Execute the graph - should raise ValueError due to unsupported node type
+ with pytest.raises(ValueError, match="Node 'unsupported_node' of type .* is not supported"):
await graph.invoke_async("test task")
mock_strands_tracer.start_multiagent_span.assert_called()
@@ -285,12 +289,10 @@ async def mock_invoke_failure(*args, **kwargs):
graph = builder.build()
+ # Execute the graph - should raise Exception due to failing agent
with pytest.raises(Exception, match="Simulated failure"):
await graph.invoke_async("Test error handling")
- assert graph.state.status == Status.FAILED
- assert any(node.node_id == "fail_node" for node in graph.state.failed_nodes)
- assert len(graph.state.completed_nodes) == 0
mock_strands_tracer.start_multiagent_span.assert_called()
mock_use_span.assert_called_once()
@@ -314,6 +316,91 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span):
mock_use_span.assert_called_once()
+@pytest.mark.asyncio
+async def test_cyclic_graph_execution(mock_strands_tracer, mock_use_span):
+ """Test execution of a graph with cycles."""
+ # Create mock agents with state tracking
+ agent_a = create_mock_agent("agent_a", "Agent A response")
+ agent_b = create_mock_agent("agent_b", "Agent B response")
+ agent_c = create_mock_agent("agent_c", "Agent C response")
+
+ # Add state to agents to track execution
+ agent_a.state = AgentState()
+ agent_b.state = AgentState()
+ agent_c.state = AgentState()
+
+ # Create a spy to track reset calls
+ reset_spy = MagicMock()
+
+ # Create a graph with a cycle: A -> B -> C -> A
+ builder = GraphBuilder()
+ builder.add_node(agent_a, "a")
+ builder.add_node(agent_b, "b")
+ builder.add_node(agent_c, "c")
+ builder.add_edge("a", "b")
+ builder.add_edge("b", "c")
+ builder.add_edge("c", "a") # Creates cycle
+ builder.set_entry_point("a")
+ builder.reset_on_revisit() # Enable state reset on revisit
+
+ # Patch the reset_executor_state method to track calls
+ original_reset = GraphNode.reset_executor_state
+
+ def spy_reset(self):
+ reset_spy(self.node_id)
+ original_reset(self)
+
+ with patch.object(GraphNode, "reset_executor_state", spy_reset):
+ graph = builder.build()
+
+ # Execute the graph with a task that will cause it to cycle
+ result = await graph.invoke_async("Test cyclic graph execution")
+
+ # Verify that the graph executed successfully
+ assert result.status == Status.COMPLETED
+
+ # Verify that each agent was called at least once
+ agent_a.invoke_async.assert_called()
+ agent_b.invoke_async.assert_called()
+ agent_c.invoke_async.assert_called()
+
+ # Verify that the execution order includes all nodes
+ assert len(result.execution_order) >= 3
+ assert any(node.node_id == "a" for node in result.execution_order)
+ assert any(node.node_id == "b" for node in result.execution_order)
+ assert any(node.node_id == "c" for node in result.execution_order)
+
+ # Verify that node state was reset during cyclic execution
+ # If we have more than 3 nodes in execution_order, at least one node was revisited
+ if len(result.execution_order) > 3:
+ # Check that reset_executor_state was called for revisited nodes
+ reset_spy.assert_called()
+
+ # Count occurrences of each node in execution order
+ node_counts = {}
+ for node in result.execution_order:
+ node_counts[node.node_id] = node_counts.get(node.node_id, 0) + 1
+
+ # At least one node should appear multiple times
+ assert any(count > 1 for count in node_counts.values()), "No node was revisited in the cycle"
+
+ # For each node that appears multiple times, verify reset was called
+ for node_id, count in node_counts.items():
+ if count > 1:
+ # Check that reset was called at least (count-1) times for this node
+ reset_calls = sum(1 for call in reset_spy.call_args_list if call[0][0] == node_id)
+ assert reset_calls >= count - 1, (
+ f"Node {node_id} appeared {count} times but reset was called {reset_calls} times"
+ )
+
+ # Verify all nodes were completed
+ assert result.completed_nodes == 3
+
+
def test_graph_builder_validation():
"""Test GraphBuilder validation and error handling."""
# Test empty graph validation
@@ -343,7 +430,11 @@ def test_graph_builder_validation():
node2 = GraphNode("node2", duplicate_agent) # Same agent instance
nodes = {"node1": node1, "node2": node2}
with pytest.raises(ValueError, match="Duplicate node instance detected"):
- Graph(nodes=nodes, edges=set(), entry_points=set())
+ Graph(
+ nodes=nodes,
+ edges=set(),
+ entry_points=set(),
+ )
# Test edge validation with non-existent nodes
builder = GraphBuilder()
@@ -368,7 +459,7 @@ def test_graph_builder_validation():
with pytest.raises(ValueError, match="Entry points not found in nodes"):
builder.build()
- # Test cycle detection
+ # Test cycle handling (cycles are now allowed by default)
builder = GraphBuilder()
builder.add_node(agent1, "a")
builder.add_node(agent2, "b")
@@ -378,8 +469,9 @@ def test_graph_builder_validation():
builder.add_edge("c", "a") # Creates cycle
builder.set_entry_point("a")
- with pytest.raises(ValueError, match="Graph contains cycles"):
- builder.build()
+ # Should succeed - cycles are now allowed by default
+ graph = builder.build()
+ assert any(node.node_id == "a" for node in graph.entry_points)
# Test auto-detection of entry points
builder = GraphBuilder()
@@ -400,6 +492,259 @@ def test_graph_builder_validation():
with pytest.raises(ValueError, match="No entry points found - all nodes have dependencies"):
builder.build()
+ # Test custom execution limits and reset_on_revisit
+ builder = GraphBuilder()
+ builder.add_node(agent1, "test_node")
+ graph = (
+ builder.set_max_node_executions(10)
+ .set_execution_timeout(300.0)
+ .set_node_timeout(60.0)
+ .reset_on_revisit()
+ .build()
+ )
+ assert graph.max_node_executions == 10
+ assert graph.execution_timeout == 300.0
+ assert graph.node_timeout == 60.0
+ assert graph.reset_on_revisit is True
+
+ # Test default execution limits and reset_on_revisit (None and False)
+ builder = GraphBuilder()
+ builder.add_node(agent1, "test_node")
+ graph = builder.build()
+ assert graph.max_node_executions is None
+ assert graph.execution_timeout is None
+ assert graph.node_timeout is None
+ assert graph.reset_on_revisit is False
+
+
+@pytest.mark.asyncio
+async def test_graph_execution_limits(mock_strands_tracer, mock_use_span):
+ """Test graph execution limits (max_node_executions and execution_timeout)."""
+ # Test with a simple linear graph first to verify limits work
+ agent_a = create_mock_agent("agent_a", "Response A")
+ agent_b = create_mock_agent("agent_b", "Response B")
+ agent_c = create_mock_agent("agent_c", "Response C")
+
+ # Create a linear graph: a -> b -> c
+ builder = GraphBuilder()
+ builder.add_node(agent_a, "a")
+ builder.add_node(agent_b, "b")
+ builder.add_node(agent_c, "c")
+ builder.add_edge("a", "b")
+ builder.add_edge("b", "c")
+ builder.set_entry_point("a")
+
+ # Test with no limits (backward compatibility) - should complete normally
+ graph = builder.build() # No limits specified
+ result = await graph.invoke_async("Test execution")
+ assert result.status == Status.COMPLETED
+ assert len(result.execution_order) == 3 # All 3 nodes should execute
+
+ # Test with limit that allows completion
+ builder = GraphBuilder()
+ builder.add_node(agent_a, "a")
+ builder.add_node(agent_b, "b")
+ builder.add_node(agent_c, "c")
+ builder.add_edge("a", "b")
+ builder.add_edge("b", "c")
+ builder.set_entry_point("a")
+ graph = builder.set_max_node_executions(5).set_execution_timeout(900.0).set_node_timeout(300.0).build()
+ result = await graph.invoke_async("Test execution")
+ assert result.status == Status.COMPLETED
+ assert len(result.execution_order) == 3 # All 3 nodes should execute
+
+ # Test with limit that prevents full completion
+ builder = GraphBuilder()
+ builder.add_node(agent_a, "a")
+ builder.add_node(agent_b, "b")
+ builder.add_node(agent_c, "c")
+ builder.add_edge("a", "b")
+ builder.add_edge("b", "c")
+ builder.set_entry_point("a")
+ graph = builder.set_max_node_executions(2).set_execution_timeout(900.0).set_node_timeout(300.0).build()
+ result = await graph.invoke_async("Test execution limit")
+ assert result.status == Status.FAILED # Should fail due to limit
+ assert len(result.execution_order) == 2 # Should stop at 2 executions
+
+ # Test execution timeout by manipulating start time (like Swarm does)
+ timeout_agent_a = create_mock_agent("timeout_agent_a", "Response A")
+ timeout_agent_b = create_mock_agent("timeout_agent_b", "Response B")
+
+ # Create a cyclic graph that would run indefinitely
+ builder = GraphBuilder()
+ builder.add_node(timeout_agent_a, "a")
+ builder.add_node(timeout_agent_b, "b")
+ builder.add_edge("a", "b")
+ builder.add_edge("b", "a") # Creates cycle
+ builder.set_entry_point("a")
+
+ # Enable reset_on_revisit so the cycle can continue
+ graph = builder.reset_on_revisit(True).set_execution_timeout(5.0).set_max_node_executions(100).build()
+
+ # Run the graph, then manipulate the start time to simulate a timeout (like Swarm does)
+ result = await graph.invoke_async("Test execution timeout")
+ # Manually set the start time to 10 seconds ago to trigger the timeout condition
+ graph.state.start_time = time.time() - 10
+
+ # Check the timeout logic directly
+ should_continue, reason = graph.state.should_continue(max_node_executions=100, execution_timeout=5.0)
+ assert should_continue is False
+ assert "Execution timed out" in reason
+
+ mock_strands_tracer.start_multiagent_span.assert_called()
+ mock_use_span.assert_called()
+
+
+@pytest.mark.asyncio
+async def test_graph_node_timeout(mock_strands_tracer, mock_use_span):
+ """Test individual node timeout functionality."""
+
+ # Create a mock agent that takes longer than the node timeout
+ timeout_agent = create_mock_agent("timeout_agent", "Should timeout")
+
+ async def timeout_invoke(*args, **kwargs):
+ await asyncio.sleep(0.2) # Longer than node timeout
+ return timeout_agent.return_value
+
+ timeout_agent.invoke_async = AsyncMock(side_effect=timeout_invoke)
+
+ builder = GraphBuilder()
+ builder.add_node(timeout_agent, "timeout_node")
+
+ # Test with no timeout (backward compatibility) - should complete normally
+ graph = builder.build() # No timeout specified
+ result = await graph.invoke_async("Test no timeout")
+ assert result.status == Status.COMPLETED
+ assert result.completed_nodes == 1
+
+ # Test with very short node timeout - should raise timeout exception
+ builder = GraphBuilder()
+ builder.add_node(timeout_agent, "timeout_node")
+ graph = builder.set_max_node_executions(50).set_execution_timeout(900.0).set_node_timeout(0.1).build()
+
+ # Execute the graph - should raise Exception due to timeout
+ with pytest.raises(Exception, match="Node 'timeout_node' execution timed out after 0.1s"):
+ await graph.invoke_async("Test node timeout")
+
+ mock_strands_tracer.start_multiagent_span.assert_called()
+ mock_use_span.assert_called()
+
+
+@pytest.mark.asyncio
+async def test_backward_compatibility_no_limits():
+ """Test that graphs with no limits specified work exactly as before."""
+ # Create simple agents
+ agent_a = create_mock_agent("agent_a", "Response A")
+ agent_b = create_mock_agent("agent_b", "Response B")
+
+ # Create a simple linear graph
+ builder = GraphBuilder()
+ builder.add_node(agent_a, "a")
+ builder.add_node(agent_b, "b")
+ builder.add_edge("a", "b")
+ builder.set_entry_point("a")
+
+ # Build without specifying any limits - should work exactly as before
+ graph = builder.build()
+
+ # Verify the limits are None (no limits)
+ assert graph.max_node_executions is None
+ assert graph.execution_timeout is None
+ assert graph.node_timeout is None
+
+ # Execute the graph - should complete normally
+ result = await graph.invoke_async("Test backward compatibility")
+ assert result.status == Status.COMPLETED
+ assert len(result.execution_order) == 2 # Both nodes should execute
+
+
+@pytest.mark.asyncio
+async def test_node_reset_executor_state():
+ """Test that GraphNode.reset_executor_state properly resets node state."""
+ # Create a mock agent with state
+ agent = create_mock_agent("test_agent", "Test response")
+ agent.state = AgentState()
+ agent.state.set("test_key", "test_value")
+ agent.messages = [{"role": "system", "content": "Initial system message"}]
+
+ # Create a GraphNode with this agent
+ node = GraphNode("test_node", agent)
+
+ # Verify initial state is captured during initialization
+ assert len(node._initial_messages) == 1
+ assert node._initial_messages[0]["role"] == "system"
+ assert node._initial_messages[0]["content"] == "Initial system message"
+
+ # Modify agent state and messages after initialization
+ agent.state.set("new_key", "new_value")
+ agent.messages.append({"role": "user", "content": "New message"})
+
+ # Also modify execution status and result
+ node.execution_status = Status.COMPLETED
+ node.result = NodeResult(
+ result="test result",
+ execution_time=100,
+ status=Status.COMPLETED,
+ accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
+ accumulated_metrics={"latencyMs": 100},
+ execution_count=1,
+ )
+
+ # Verify state was modified
+ assert len(agent.messages) == 2
+ assert agent.state.get("new_key") == "new_value"
+ assert node.execution_status == Status.COMPLETED
+ assert node.result is not None
+
+ # Reset the executor state
+ node.reset_executor_state()
+
+ # Verify messages were reset to initial values
+ assert len(agent.messages) == 1
+ assert agent.messages[0]["role"] == "system"
+ assert agent.messages[0]["content"] == "Initial system message"
+
+ # Verify agent state was reset
+ # The test_key should be gone since it wasn't in the initial state
+ assert agent.state.get("new_key") is None
+
+ # Verify execution status is reset
+ assert node.execution_status == Status.PENDING
+ assert node.result is None
+
+ # Test with MultiAgentBase executor
+ multi_agent = create_mock_multi_agent("multi_agent")
+ multi_agent_node = GraphNode("multi_node", multi_agent)
+
+ # Since MultiAgentBase doesn't have messages or state attributes,
+ # reset_executor_state should not fail
+ multi_agent_node.execution_status = Status.COMPLETED
+ multi_agent_node.result = NodeResult(
+ result="test result",
+ execution_time=100,
+ status=Status.COMPLETED,
+ accumulated_usage={},
+ accumulated_metrics={},
+ execution_count=1,
+ )
+
+ # Reset should work without errors
+ multi_agent_node.reset_executor_state()
+
+ # Verify execution status is reset
+ assert multi_agent_node.execution_status == Status.PENDING
+ assert multi_agent_node.result is None
+
def test_graph_dataclasses_and_enums():
"""Test dataclass initialization, properties, and enum behavior."""
@@ -417,6 +762,7 @@ def test_graph_dataclasses_and_enums():
assert state.task == ""
assert state.accumulated_usage == {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}
assert state.execution_count == 0
+ assert state.start_time > 0 # Should be set by default factory
# Test GraphState with custom values
state = GraphState(status=Status.EXECUTING, task="custom task", total_nodes=5, execution_count=3)
@@ -540,9 +886,222 @@ def register_hooks(self, registry, **kwargs):
# Test with session manager in Graph constructor
node_with_session = GraphNode("node_with_session", agent_with_session)
with pytest.raises(ValueError, match="Session persistence is not supported for Graph agents yet"):
- Graph(nodes={"node_with_session": node_with_session}, edges=set(), entry_points=set())
+ Graph(
+ nodes={"node_with_session": node_with_session},
+ edges=set(),
+ entry_points=set(),
+ )
# Test with callbacks in Graph constructor
node_with_hooks = GraphNode("node_with_hooks", agent_with_hooks)
with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"):
- Graph(nodes={"node_with_hooks": node_with_hooks}, edges=set(), entry_points=set())
+ Graph(
+ nodes={"node_with_hooks": node_with_hooks},
+ edges=set(),
+ entry_points=set(),
+ )
+
+
+@pytest.mark.asyncio
+async def test_controlled_cyclic_execution():
+ """Test cyclic graph execution with controlled cycle count to verify state reset."""
+
+ # Create a stateful agent that tracks its own execution count
+ class StatefulAgent(Agent):
+ def __init__(self, name):
+ super().__init__()
+ self.name = name
+ self.state = AgentState()
+ self.state.set("execution_count", 0)
+ self.messages = []
+ self._session_manager = None
+ self.hooks = HookRegistry()
+
+ async def invoke_async(self, input_data):
+ # Increment execution count in state
+ count = self.state.get("execution_count") or 0
+ self.state.set("execution_count", count + 1)
+
+ return AgentResult(
+ message={"role": "assistant", "content": [{"text": f"{self.name} response (execution {count + 1})"}]},
+ stop_reason="end_turn",
+ state={},
+ metrics=Mock(
+ accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
+ accumulated_metrics={"latencyMs": 100.0},
+ ),
+ )
+
+ # Create agents
+ agent_a = StatefulAgent("agent_a")
+ agent_b = StatefulAgent("agent_b")
+
+ # Create a graph with a simple cycle: A -> B -> A
+ builder = GraphBuilder()
+ builder.add_node(agent_a, "a")
+ builder.add_node(agent_b, "b")
+ builder.add_edge("a", "b")
+ builder.add_edge("b", "a") # Creates cycle
+ builder.set_entry_point("a")
+ builder.reset_on_revisit() # Enable state reset on revisit
+
+ # Build with limited max_node_executions to prevent infinite loop
+ graph = builder.set_max_node_executions(3).build()
+
+ # Execute the graph
+ result = await graph.invoke_async("Test controlled cyclic execution")
+
+ # With a 2-node cycle and limit of 3, we should see either completion or failure
+ # The exact behavior depends on how the cycle detection works
+ if result.status == Status.COMPLETED:
+ # If it completed, verify it executed some nodes
+ assert len(result.execution_order) >= 2
+ assert result.execution_order[0].node_id == "a"
+ elif result.status == Status.FAILED:
+ # If it failed due to limits, verify it hit the limit
+ assert len(result.execution_order) == 3 # Should stop at exactly 3 executions
+ assert result.execution_order[0].node_id == "a"
+ else:
+ # Should be either completed or failed
+ raise AssertionError(f"Unexpected status: {result.status}")
+
+ # Most importantly, verify that state was reset properly between executions
+ # The state.execution_count should be set for both agents after execution
+ assert agent_a.state.get("execution_count") >= 1 # Node A executed at least once
+ assert agent_b.state.get("execution_count") >= 1 # Node B executed at least once
+
+
+def test_reset_on_revisit_backward_compatibility():
+ """Test that reset_on_revisit provides backward compatibility by default."""
+ agent1 = create_mock_agent("agent1")
+ agent2 = create_mock_agent("agent2")
+
+ # Test default behavior - reset_on_revisit is False by default
+ builder = GraphBuilder()
+ builder.add_node(agent1, "a")
+ builder.add_node(agent2, "b")
+ builder.add_edge("a", "b")
+ builder.set_entry_point("a")
+
+ graph = builder.build()
+ assert graph.reset_on_revisit is False
+
+ # Test reset_on_revisit with True
+ builder = GraphBuilder()
+ builder.add_node(agent1, "a")
+ builder.add_node(agent2, "b")
+ builder.add_edge("a", "b")
+ builder.set_entry_point("a")
+ builder.reset_on_revisit(True)
+
+ graph = builder.build()
+ assert graph.reset_on_revisit is True
+
+ # Test reset_on_revisit with False explicitly
+ builder = GraphBuilder()
+ builder.add_node(agent1, "a")
+ builder.add_node(agent2, "b")
+ builder.add_edge("a", "b")
+ builder.set_entry_point("a")
+ builder.reset_on_revisit(False)
+
+ graph = builder.build()
+ assert graph.reset_on_revisit is False
+
+
+def test_reset_on_revisit_method_chaining():
+ """Test that reset_on_revisit method returns GraphBuilder for chaining."""
+ agent1 = create_mock_agent("agent1")
+
+ builder = GraphBuilder()
+ result = builder.reset_on_revisit()
+
+ # Verify method chaining works
+ assert result is builder
+ assert builder._reset_on_revisit is True
+
+ # Test full method chaining
+ builder.add_node(agent1, "test_node")
+ builder.set_max_node_executions(10)
+ graph = builder.build()
+
+ assert graph.reset_on_revisit is True
+ assert graph.max_node_executions == 10
+
+
+@pytest.mark.asyncio
+async def test_linear_graph_behavior():
+ """Test that linear graph behavior works correctly."""
+ agent_a = create_mock_agent("agent_a", "Response A")
+ agent_b = create_mock_agent("agent_b", "Response B")
+
+ # Create linear graph
+ builder = GraphBuilder()
+ builder.add_node(agent_a, "a")
+ builder.add_node(agent_b, "b")
+ builder.add_edge("a", "b")
+ builder.set_entry_point("a")
+
+ graph = builder.build()
+ assert graph.reset_on_revisit is False
+
+ # Execute should work normally
+ result = await graph.invoke_async("Test linear execution")
+ assert result.status == Status.COMPLETED
+ assert len(result.execution_order) == 2
+ assert result.execution_order[0].node_id == "a"
+ assert result.execution_order[1].node_id == "b"
+
+ # Verify agents were called once each (no state reset)
+ agent_a.invoke_async.assert_called_once()
+ agent_b.invoke_async.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_state_reset_only_with_cycles_enabled():
+ """Test that state reset only happens when cycles are enabled."""
+ # Create a mock agent that tracks state modifications
+ agent = create_mock_agent("test_agent", "Test response")
+ agent.state = AgentState()
+ agent.messages = [{"role": "system", "content": "Initial message"}]
+
+ # Create GraphNode
+ node = GraphNode("test_node", agent)
+
+ # Simulate agent being in completed_nodes (as if revisited)
+ from strands.multiagent.graph import GraphState
+
+ state = GraphState()
+ state.completed_nodes.add(node)
+
+ # Create graph with cycles disabled (default)
+ builder = GraphBuilder()
+ builder.add_node(agent, "test_node")
+ graph = builder.build()
+
+ # Mock the _execute_node method to test conditional reset logic
+ import unittest.mock
+
+ with unittest.mock.patch.object(node, "reset_executor_state") as mock_reset:
+ # Simulate the conditional logic from _execute_node
+ if graph.reset_on_revisit and node in state.completed_nodes:
+ node.reset_executor_state()
+ state.completed_nodes.remove(node)
+
+ # With reset_on_revisit disabled, reset should not be called
+ mock_reset.assert_not_called()
+
+ # Now test with reset_on_revisit enabled
+ builder = GraphBuilder()
+ builder.add_node(agent, "test_node")
+ builder.reset_on_revisit()
+ graph = builder.build()
+
+ with unittest.mock.patch.object(node, "reset_executor_state") as mock_reset:
+ # Simulate the conditional logic from _execute_node
+ if graph.reset_on_revisit and node in state.completed_nodes:
+ node.reset_executor_state()
+ state.completed_nodes.remove(node)
+
+ # With reset_on_revisit enabled, reset should be called
+ mock_reset.assert_called_once()
From 72709cf16d40b985d05ecf2ddb2081fbe28d1aa2 Mon Sep 17 00:00:00 2001
From: poshinchen
Date: Mon, 11 Aug 2025 14:35:57 -0400
Subject: [PATCH 028/104] chore: request to include code snippet section (#654)
---
.github/ISSUE_TEMPLATE/bug_report.yml | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index 3c357173c..b3898b7f7 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -61,9 +61,10 @@ body:
label: Steps to Reproduce
description: Detailed steps to reproduce the behavior
placeholder: |
- 1. Install Strands using...
- 2. Run the command...
- 3. See error...
+ 1. Code Snippet (Minimal reproducible example)
+ 2. Install Strands using...
+ 3. Run the command...
+ 4. See error...
validations:
required: true
- type: textarea
From 8434409a1f85816c6ec42756c79eb05b0914d6d1 Mon Sep 17 00:00:00 2001
From: fhwilton55 <81768750+fhwilton55@users.noreply.github.com>
Date: Tue, 12 Aug 2025 18:16:53 -0400
Subject: [PATCH 029/104] feat: Add configuration option to MCP Client for
server init timeout (#657)
Co-authored-by: Harry Wilton
---
src/strands/tools/mcp/mcp_client.py | 17 ++++++++++++-----
1 file changed, 12 insertions(+), 5 deletions(-)
diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py
index c1aa96df3..7cb03e46f 100644
--- a/src/strands/tools/mcp/mcp_client.py
+++ b/src/strands/tools/mcp/mcp_client.py
@@ -63,17 +63,23 @@ class MCPClient:
from MCP tools, it will be returned as the last item in the content array of the ToolResult.
"""
- def __init__(self, transport_callable: Callable[[], MCPTransport]):
+ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_timeout: int = 30):
"""Initialize a new MCP Server connection.
Args:
transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple
+ startup_timeout: Timeout in seconds after which MCP server initialization is cancelled.
+ Defaults to 30.
"""
+ self._startup_timeout = startup_timeout
+
mcp_instrumentation()
self._session_id = uuid.uuid4()
self._log_debug_with_thread("initializing MCPClient connection")
- self._init_future: futures.Future[None] = futures.Future() # Main thread blocks until future completes
- self._close_event = asyncio.Event() # Do not want to block other threads while close event is false
+ # Main thread blocks until future completes
+ self._init_future: futures.Future[None] = futures.Future()
+ # Do not want to block other threads while close event is false
+ self._close_event = asyncio.Event()
self._transport_callable = transport_callable
self._background_thread: threading.Thread | None = None
@@ -109,7 +115,7 @@ def start(self) -> "MCPClient":
self._log_debug_with_thread("background thread started, waiting for ready event")
try:
# Blocking main thread until session is initialized in other thread or if the thread stops
- self._init_future.result(timeout=30)
+ self._init_future.result(timeout=self._startup_timeout)
self._log_debug_with_thread("the client initialization was successful")
except futures.TimeoutError as e:
raise MCPClientInitializationError(f"background thread did not start in {self._startup_timeout} seconds") from e
@@ -347,7 +353,8 @@ async def _async_background_thread(self) -> None:
self._log_debug_with_thread("session initialized successfully")
# Store the session for use while we await the close event
self._background_thread_session = session
- self._init_future.set_result(None) # Signal that the session has been created and is ready for use
+ # Signal that the session has been created and is ready for use
+ self._init_future.set_result(None)
self._log_debug_with_thread("waiting for close signal")
# Keep background thread running until signaled to close.
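With the new keyword, callers connecting to slow-starting servers can raise the initialization budget. A sketch (the stdio transport and the server command are assumptions for illustration):

```python
from mcp import StdioServerParameters, stdio_client

from strands.tools.mcp import MCPClient

# Allow a slow-starting server 120 seconds to initialize instead of the default 30.
# The timeout is enforced when the connection is started (e.g. client.start()).
client = MCPClient(
    lambda: stdio_client(StdioServerParameters(command="my-mcp-server")),
    startup_timeout=120,
)
```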
From 49ff22678b27b737658d2b6215365c454bc19db6 Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Wed, 13 Aug 2025 08:58:33 -0400
Subject: [PATCH 030/104] fix: Bedrock hang when exception occurs during
message conversion (#643)
Previously (#642), Bedrock would hang during message conversion because exceptions were not caught, leaving the response queue empty forever. Now all exceptions raised during conversion are caught.
Co-authored-by: Mackenzie Zastrow
---
pyproject.toml | 2 +-
src/strands/models/bedrock.py | 12 ++++++------
tests/strands/models/test_bedrock.py | 9 +++++++++
3 files changed, 16 insertions(+), 7 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 586a956af..d4a4b79dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -234,8 +234,8 @@ test-integ = [
"hatch test tests_integ {args}"
]
prepare = [
- "hatch fmt --linter",
"hatch fmt --formatter",
+ "hatch fmt --linter",
"hatch run test-lint",
"hatch test --all"
]
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index 4ea1453a4..ace35640a 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -418,14 +418,14 @@ def _stream(
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the model service is throttling requests.
"""
- logger.debug("formatting request")
- request = self.format_request(messages, tool_specs, system_prompt)
- logger.debug("request=<%s>", request)
+ try:
+ logger.debug("formatting request")
+ request = self.format_request(messages, tool_specs, system_prompt)
+ logger.debug("request=<%s>", request)
- logger.debug("invoking model")
- streaming = self.config.get("streaming", True)
+ logger.debug("invoking model")
+ streaming = self.config.get("streaming", True)
- try:
logger.debug("got response from model")
if streaming:
response = self.client.converse_stream(**request)
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index 0a2846adf..09e508845 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -419,6 +419,15 @@ async def test_stream_throttling_exception_from_event_stream_error(bedrock_clien
)
+@pytest.mark.asyncio
+async def test_stream_with_invalid_content_throws(bedrock_client, model, alist):
+ # We used to hang on None, so ensure we don't regress: https://github.com/strands-agents/sdk-python/issues/642
+ messages = [{"role": "user", "content": None}]
+
+ with pytest.raises(TypeError):
+ await alist(model.stream(messages))
+
+
@pytest.mark.asyncio
async def test_stream_throttling_exception_from_general_exception(bedrock_client, model, messages, alist):
error_message = "ThrottlingException: Rate exceeded for ConverseStream"
From 04557562eb4345abb65bf056c2889b1586dab277 Mon Sep 17 00:00:00 2001
From: poshinchen
Date: Wed, 13 Aug 2025 17:20:34 -0400
Subject: [PATCH 031/104] feat: add structured_output_span (#655)
* feat: add structured_output_span
---
src/strands/agent/agent.py | 65 ++++++++++++++++++++-----------
tests/strands/agent/test_agent.py | 39 +++++++++++++++++++
2 files changed, 82 insertions(+), 22 deletions(-)
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 2022142c6..43b5cbf8c 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -33,7 +33,7 @@
from ..models.model import Model
from ..session.session_manager import SessionManager
from ..telemetry.metrics import EventLoopMetrics
-from ..telemetry.tracer import get_tracer
+from ..telemetry.tracer import get_tracer, serialize
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
from ..types.content import ContentBlock, Message, Messages
@@ -445,27 +445,48 @@ async def structured_output_async(
ValueError: If no conversation history or prompt is provided.
"""
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
-
- try:
- if not self.messages and not prompt:
- raise ValueError("No conversation history or prompt provided")
-
- # Create temporary messages array if prompt is provided
- if prompt:
- content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
- temp_messages = self.messages + [{"role": "user", "content": content}]
- else:
- temp_messages = self.messages
-
- events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
- async for event in events:
- if "callback" in event:
- self.callback_handler(**cast(dict, event["callback"]))
-
- return event["output"]
-
- finally:
- self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
+ with self.tracer.tracer.start_as_current_span(
+ "execute_structured_output", kind=trace_api.SpanKind.CLIENT
+ ) as structured_output_span:
+ try:
+ if not self.messages and not prompt:
+ raise ValueError("No conversation history or prompt provided")
+ # Create temporary messages array if prompt is provided
+ if prompt:
+ content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
+ temp_messages = self.messages + [{"role": "user", "content": content}]
+ else:
+ temp_messages = self.messages
+
+ structured_output_span.set_attributes(
+ {
+ "gen_ai.system": "strands-agents",
+ "gen_ai.agent.name": self.name,
+ "gen_ai.agent.id": self.agent_id,
+ "gen_ai.operation.name": "execute_structured_output",
+ }
+ )
+ for message in temp_messages:
+ structured_output_span.add_event(
+ f"gen_ai.{message['role']}.message",
+ attributes={"role": message["role"], "content": serialize(message["content"])},
+ )
+ if self.system_prompt:
+ structured_output_span.add_event(
+ "gen_ai.system.message",
+ attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
+ )
+ events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
+ async for event in events:
+ if "callback" in event:
+ self.callback_handler(**cast(dict, event["callback"]))
+ structured_output_span.add_event(
+ "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
+ )
+ return event["output"]
+
+ finally:
+ self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index c27243dfe..fdce7c368 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -980,6 +980,14 @@ def test_agent_callback_handler_custom_handler_used():
def test_agent_structured_output(agent, system_prompt, user, agenerator):
+ # Setup mock tracer and span
+ mock_strands_tracer = unittest.mock.MagicMock()
+ mock_otel_tracer = unittest.mock.MagicMock()
+ mock_span = unittest.mock.MagicMock()
+ mock_strands_tracer.tracer = mock_otel_tracer
+ mock_otel_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span
+ agent.tracer = mock_strands_tracer
+
agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
prompt = "Jane Doe is 30 years old and her email is jane@doe.com"
@@ -999,8 +1007,34 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
)
+ mock_span.set_attributes.assert_called_once_with(
+ {
+ "gen_ai.system": "strands-agents",
+ "gen_ai.agent.name": "Strands Agents",
+ "gen_ai.agent.id": "default",
+ "gen_ai.operation.name": "execute_structured_output",
+ }
+ )
+
+ mock_span.add_event.assert_any_call(
+ "gen_ai.user.message",
+ attributes={"role": "user", "content": '[{"text": "Jane Doe is 30 years old and her email is jane@doe.com"}]'},
+ )
+
+ mock_span.add_event.assert_called_with(
+ "gen_ai.choice",
+ attributes={"message": json.dumps(user.model_dump())},
+ )
+
def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, agenerator):
+ # Setup mock tracer and span
+ mock_strands_tracer = unittest.mock.MagicMock()
+ mock_otel_tracer = unittest.mock.MagicMock()
+ mock_span = unittest.mock.MagicMock()
+ mock_strands_tracer.tracer = mock_otel_tracer
+ mock_otel_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span
+ agent.tracer = mock_strands_tracer
agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
prompt = [
@@ -1030,6 +1064,11 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a
type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt
)
+ mock_span.add_event.assert_called_with(
+ "gen_ai.choice",
+ attributes={"message": json.dumps(user.model_dump())},
+ )
+
@pytest.mark.asyncio
async def test_agent_structured_output_in_async_context(agent, user, agenerator):
From 1c7257bc9e2356d025c5fa77f6a3b1e959809964 Mon Sep 17 00:00:00 2001
From: Patrick Gray
Date: Thu, 14 Aug 2025 10:32:58 -0400
Subject: [PATCH 032/104] litellm - set 1.73.1 as minimum version (#668)
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index d4a4b79dc..487b26691 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,7 @@ docs = [
"sphinx-autodoc-typehints>=1.12.0,<2.0.0",
]
litellm = [
- "litellm>=1.72.6,<1.73.0",
+ "litellm>=1.73.1,<2.0.0",
]
llamaapi = [
"llama-api-client>=0.1.0,<1.0.0",
From 606f65756668274d3acf2600b76df10745a08f1f Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Thu, 14 Aug 2025 14:49:08 -0400
Subject: [PATCH 033/104] feat: expose tool_use and agent through ToolContext
to decorated tools (#557)
---
src/strands/__init__.py | 3 +-
src/strands/tools/decorator.py | 82 +++++++++--
src/strands/types/tools.py | 27 +++-
tests/strands/tools/test_decorator.py | 159 ++++++++++++++++++++-
tests_integ/test_tool_context_injection.py | 56 ++++++++
5 files changed, 312 insertions(+), 15 deletions(-)
create mode 100644 tests_integ/test_tool_context_injection.py
diff --git a/src/strands/__init__.py b/src/strands/__init__.py
index e9f9e9cd8..ae784a58f 100644
--- a/src/strands/__init__.py
+++ b/src/strands/__init__.py
@@ -3,5 +3,6 @@
from . import agent, models, telemetry, types
from .agent.agent import Agent
from .tools.decorator import tool
+from .types.tools import ToolContext
-__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry"]
+__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "ToolContext"]
diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py
index 5ec324b68..75abac9ed 100644
--- a/src/strands/tools/decorator.py
+++ b/src/strands/tools/decorator.py
@@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from pydantic import BaseModel, Field, create_model
from typing_extensions import override
-from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse
+from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse
logger = logging.getLogger(__name__)
@@ -84,16 +84,18 @@ class FunctionToolMetadata:
validate tool usage.
"""
- def __init__(self, func: Callable[..., Any]) -> None:
+ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -> None:
"""Initialize with the function to process.
Args:
func: The function to extract metadata from.
Can be a standalone function or a class method.
+ context_param: Name of the context parameter to inject, if any.
"""
self.func = func
self.signature = inspect.signature(func)
self.type_hints = get_type_hints(func)
+ self._context_param = context_param
# Parse the docstring with docstring_parser
doc_str = inspect.getdoc(func) or ""
@@ -113,7 +115,7 @@ def _create_input_model(self) -> Type[BaseModel]:
This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can
validate input data before passing it to the function.
- Special parameters like 'self', 'cls', and 'agent' are excluded from the model.
+ Special parameters that can be automatically injected are excluded from the model.
Returns:
A Pydantic BaseModel class customized for the function's parameters.
@@ -121,8 +123,8 @@ def _create_input_model(self) -> Type[BaseModel]:
field_definitions: dict[str, Any] = {}
for name, param in self.signature.parameters.items():
- # Skip special parameters
- if name in ("self", "cls", "agent"):
+ # Skip parameters that will be automatically injected
+ if self._is_special_parameter(name):
continue
# Get parameter type and default
@@ -252,6 +254,49 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
error_msg = str(e)
raise ValueError(f"Validation failed for input parameters: {error_msg}") from e
+ def inject_special_parameters(
+ self, validated_input: dict[str, Any], tool_use: ToolUse, invocation_state: dict[str, Any]
+ ) -> None:
+ """Inject special framework-provided parameters into the validated input.
+
+ This method automatically provides framework-level context to tools that request it
+ through their function signature.
+
+ Args:
+ validated_input: The validated input parameters (modified in place).
+ tool_use: The tool use request containing tool invocation details.
+ invocation_state: Context for the tool invocation, including agent state.
+ """
+ if self._context_param and self._context_param in self.signature.parameters:
+ tool_context = ToolContext(tool_use=tool_use, agent=invocation_state["agent"])
+ validated_input[self._context_param] = tool_context
+
+ # Inject agent if requested (backward compatibility)
+ if "agent" in self.signature.parameters and "agent" in invocation_state:
+ validated_input["agent"] = invocation_state["agent"]
+
+ def _is_special_parameter(self, param_name: str) -> bool:
+ """Check if a parameter should be automatically injected by the framework or is a standard Python method param.
+
+ Special parameters include:
+ - Standard Python method parameters: self, cls
+ - Framework-provided context parameters: agent, and configurable context parameter (defaults to tool_context)
+
+ Args:
+ param_name: The name of the parameter to check.
+
+ Returns:
+ True if the parameter should be excluded from input validation and
+ handled specially during tool execution.
+ """
+ special_params = {"self", "cls", "agent"}
+
+ # Add context parameter if configured
+ if self._context_param:
+ special_params.add(self._context_param)
+
+ return param_name in special_params
+
P = ParamSpec("P") # Captures all parameters
R = TypeVar("R") # Return type
@@ -402,9 +447,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
# Validate input against the Pydantic model
validated_input = self._metadata.validate_input(tool_input)
- # Pass along the agent if provided and expected by the function
- if "agent" in invocation_state and "agent" in self._metadata.signature.parameters:
- validated_input["agent"] = invocation_state.get("agent")
+ # Inject special framework-provided parameters
+ self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state)
# "Too few arguments" expected, hence the type ignore
if inspect.iscoroutinefunction(self._tool_func):
@@ -474,6 +518,7 @@ def tool(
description: Optional[str] = None,
inputSchema: Optional[JSONSchema] = None,
name: Optional[str] = None,
+ context: bool | str = False,
) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ...
# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
# call site, but the actual implementation handles that and it's not representable via the type-system
@@ -482,6 +527,7 @@ def tool( # type: ignore
description: Optional[str] = None,
inputSchema: Optional[JSONSchema] = None,
name: Optional[str] = None,
+ context: bool | str = False,
) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]:
"""Decorator that transforms a Python function into a Strands tool.
@@ -507,6 +553,9 @@ def tool( # type: ignore
description: Optional custom description to override the function's docstring.
inputSchema: Optional custom JSON schema to override the automatically generated schema.
name: Optional custom name to override the function's name.
+ context: When enabled, injects a ToolContext object into the designated parameter. If True,
+ the parameter name defaults to 'tool_context'; to use a different name, pass a string
+ naming the parameter instead.
Returns:
An AgentTool that also mimics the original function when invoked
@@ -536,15 +585,24 @@ def my_tool(name: str, count: int = 1) -> str:
Example with parameters:
```python
- @tool(name="custom_tool", description="A tool with a custom name and description")
- def my_tool(name: str, count: int = 1) -> str:
- return f"Processed {name} {count} times"
+ @tool(name="custom_tool", description="A tool with a custom name and description", context=True)
+ def my_tool(name: str, tool_context: ToolContext, count: int = 1) -> str:
+ tool_id = tool_context.tool_use["toolUseId"]
+ return f"Processed {name} {count} times with tool ID {tool_id}"
```
"""
def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
+ # Resolve context parameter name
+ if isinstance(context, bool):
+ context_param = "tool_context" if context else None
+ else:
+ context_param = context.strip()
+ if not context_param:
+ raise ValueError("Context parameter name cannot be empty")
+
# Create function tool metadata
- tool_meta = FunctionToolMetadata(f)
+ tool_meta = FunctionToolMetadata(f, context_param)
tool_spec = tool_meta.extract_metadata()
if name is not None:
tool_spec["name"] = name
diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py
index 533e5529c..bb7c874f6 100644
--- a/src/strands/types/tools.py
+++ b/src/strands/types/tools.py
@@ -6,12 +6,16 @@
"""
from abc import ABC, abstractmethod
-from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union
from typing_extensions import TypedDict
from .media import DocumentContent, ImageContent
+if TYPE_CHECKING:
+ from .. import Agent
+
JSONSchema = dict
"""Type alias for JSON Schema dictionaries."""
@@ -117,6 +121,27 @@ class ToolChoiceTool(TypedDict):
name: str
+@dataclass
+class ToolContext:
+ """Context object containing framework-provided data for decorated tools.
+
+ This object provides access to framework-level information that may be useful
+ for tool implementations.
+
+ Attributes:
+ tool_use: The complete ToolUse object containing tool invocation details.
+ agent: The Agent instance executing this tool, providing access to conversation history,
+ model configuration, and other agent state.
+
+ Note:
+ This class is intended to be instantiated by the SDK. Direct construction by users
+ is not supported and may break in future versions as new fields are added.
+ """
+
+ tool_use: ToolUse
+ agent: "Agent"
+
+
ToolChoice = Union[
dict[Literal["auto"], ToolChoiceAuto],
dict[Literal["any"], ToolChoiceAny],
diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py
index 52a9282e0..246879da7 100644
--- a/tests/strands/tools/test_decorator.py
+++ b/tests/strands/tools/test_decorator.py
@@ -8,7 +8,8 @@
import pytest
import strands
-from strands.types.tools import ToolUse
+from strands import Agent
+from strands.types.tools import AgentTool, ToolContext, ToolUse
@pytest.fixture(scope="module")
@@ -1036,3 +1037,159 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]
result = (await alist(stream))[-1]
assert result["status"] == "success"
assert "NoneType: None" in result["content"][0]["text"]
+
+
+async def _run_context_injection_test(context_tool: AgentTool):
+ """Common test logic for context injection tests."""
+ tool: AgentTool = context_tool
+ generator = tool.stream(
+ tool_use={
+ "toolUseId": "test-id",
+ "name": "context_tool",
+ "input": {
+                "message": "some_message"  # note that we include neither agent nor tool_context
+ },
+ },
+ invocation_state={
+ "agent": Agent(name="test_agent"),
+ },
+ )
+ tool_results = [value async for value in generator]
+
+ assert len(tool_results) == 1
+ tool_result = tool_results[0]
+
+ assert tool_result == {
+ "status": "success",
+ "content": [
+ {"text": "Tool 'context_tool' (ID: test-id)"},
+ {"text": "injected agent 'test_agent' processed: some_message"},
+ {"text": "context agent 'test_agent'"}
+ ],
+ "toolUseId": "test-id",
+ }
+
+
+@pytest.mark.asyncio
+async def test_tool_context_injection_default():
+ """Test that ToolContext is properly injected with default parameter name (tool_context)."""
+
+ @strands.tool(context=True)
+ def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict:
+ """Tool that uses ToolContext to access tool_use_id."""
+ tool_use_id = tool_context.tool_use["toolUseId"]
+ tool_name = tool_context.tool_use["name"]
+ agent_from_tool_context = tool_context.agent
+
+ return {
+ "status": "success",
+ "content": [
+ {"text": f"Tool '{tool_name}' (ID: {tool_use_id})"},
+ {"text": f"injected agent '{agent.name}' processed: {message}"},
+ {"text": f"context agent '{agent_from_tool_context.name}'"},
+ ],
+ }
+
+ await _run_context_injection_test(context_tool)
+
+
+@pytest.mark.asyncio
+async def test_tool_context_injection_custom_name():
+ """Test that ToolContext is properly injected with custom parameter name."""
+
+ @strands.tool(context="custom_context_name")
+ def context_tool(message: str, agent: Agent, custom_context_name: ToolContext) -> dict:
+ """Tool that uses ToolContext to access tool_use_id."""
+ tool_use_id = custom_context_name.tool_use["toolUseId"]
+ tool_name = custom_context_name.tool_use["name"]
+ agent_from_tool_context = custom_context_name.agent
+
+ return {
+ "status": "success",
+ "content": [
+ {"text": f"Tool '{tool_name}' (ID: {tool_use_id})"},
+ {"text": f"injected agent '{agent.name}' processed: {message}"},
+ {"text": f"context agent '{agent_from_tool_context.name}'"},
+ ],
+ }
+
+ await _run_context_injection_test(context_tool)
+
+
+@pytest.mark.asyncio
+async def test_tool_context_injection_disabled_missing_parameter():
+ """Test that when context=False, missing tool_context parameter causes validation error."""
+
+ @strands.tool(context=False)
+ def context_tool(message: str, agent: Agent, tool_context: str) -> dict:
+ """Tool that expects tool_context as a regular string parameter."""
+ return {
+ "status": "success",
+ "content": [
+ {"text": f"Message: {message}"},
+ {"text": f"Agent: {agent.name}"},
+ {"text": f"Tool context string: {tool_context}"},
+ ],
+ }
+
+ # Verify that missing tool_context parameter causes validation error
+ tool: AgentTool = context_tool
+ generator = tool.stream(
+ tool_use={
+ "toolUseId": "test-id",
+ "name": "context_tool",
+ "input": {
+ "message": "some_message"
+                # Missing tool_context parameter - should cause a validation error instead of being auto-injected
+ },
+ },
+ invocation_state={
+ "agent": Agent(name="test_agent"),
+ },
+ )
+ tool_results = [value async for value in generator]
+
+ assert len(tool_results) == 1
+ tool_result = tool_results[0]
+
+ # Should get a validation error because tool_context is required but not provided
+ assert tool_result["status"] == "error"
+ assert "tool_context" in tool_result["content"][0]["text"].lower()
+ assert "validation" in tool_result["content"][0]["text"].lower()
+
+
+@pytest.mark.asyncio
+async def test_tool_context_injection_disabled_string_parameter():
+ """Test that when context=False, tool_context can be passed as a string parameter."""
+
+ @strands.tool(context=False)
+ def context_tool(message: str, agent: Agent, tool_context: str) -> str:
+ """Tool that expects tool_context as a regular string parameter."""
+ return "success"
+
+ # Verify that providing tool_context as a string works correctly
+ tool: AgentTool = context_tool
+ generator = tool.stream(
+ tool_use={
+ "toolUseId": "test-id-2",
+ "name": "context_tool",
+ "input": {
+ "message": "some_message",
+ "tool_context": "my_custom_context_string"
+ },
+ },
+ invocation_state={
+ "agent": Agent(name="test_agent"),
+ },
+ )
+ tool_results = [value async for value in generator]
+
+ assert len(tool_results) == 1
+ tool_result = tool_results[0]
+
+ # Should succeed with the string parameter
+ assert tool_result == {
+ "status": "success",
+ "content": [{"text": "success"}],
+ "toolUseId": "test-id-2",
+ }
diff --git a/tests_integ/test_tool_context_injection.py b/tests_integ/test_tool_context_injection.py
new file mode 100644
index 000000000..3098604f1
--- /dev/null
+++ b/tests_integ/test_tool_context_injection.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+"""
+Integration test for ToolContext functionality with real agent interactions.
+"""
+
+from strands import Agent, ToolContext, tool
+from strands.types.tools import ToolResult
+
+
+@tool(context="custom_context_field")
+def good_story(message: str, custom_context_field: ToolContext) -> dict:
+ """Tool that writes a good story"""
+ tool_use_id = custom_context_field.tool_use["toolUseId"]
+ return {
+ "status": "success",
+ "content": [{"text": f"Context tool processed with ID: {tool_use_id}"}],
+ }
+
+
+@tool(context=True)
+def bad_story(message: str, tool_context: ToolContext) -> dict:
+ """Tool that writes a bad story"""
+ tool_use_id = tool_context.tool_use["toolUseId"]
+ return {
+ "status": "success",
+ "content": [{"text": f"Context tool processed with ID: {tool_use_id}"}],
+ }
+
+
+def _validate_tool_result_content(agent: Agent):
+ first_tool_result: ToolResult = [
+ block["toolResult"] for message in agent.messages for block in message["content"] if "toolResult" in block
+ ][0]
+
+ assert first_tool_result["status"] == "success"
+ assert (
+ first_tool_result["content"][0]["text"] == f"Context tool processed with ID: {first_tool_result['toolUseId']}"
+ )
+
+
+def test_strands_context_integration_context_true():
+    """Test ToolContext injection with context=True against a real agent."""
+
+    agent = Agent(tools=[bad_story])
+    agent("using a tool, write a bad story")
+
+ _validate_tool_result_content(agent)
+
+
+def test_strands_context_integration_context_custom():
+    """Test ToolContext injection with a custom context parameter name against a real agent."""
+
+    agent = Agent(tools=[good_story])
+    agent("using a tool, write a good story")
+
+ _validate_tool_result_content(agent)
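For orientation, here is a minimal sketch of the decorator surface this patch adds, mirroring the integration test above; the `echo` tool name and prompt are illustrative, and running the final two lines requires the same model credentials as the integration tests:

```python
from strands import Agent, ToolContext, tool


@tool(context=True)  # the framework injects a ToolContext into `tool_context`
def echo(message: str, tool_context: ToolContext) -> dict:
    """Echo a message along with the framework-provided tool use id."""
    return {
        "status": "success",
        "content": [{"text": f"{message} (id={tool_context.tool_use['toolUseId']})"}],
    }


# The model only ever supplies `message`; `tool_context` is excluded from the
# tool's input schema and filled in by the framework at execution time.
agent = Agent(tools=[echo])
agent("use the echo tool to repeat: hello")
```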
From 8c63d75ecf9c246110d297c109bf204839978152 Mon Sep 17 00:00:00 2001
From: Patrick Gray
Date: Fri, 15 Aug 2025 17:50:42 -0400
Subject: [PATCH 034/104] session manager - prevent file path injection (#680)
---
src/strands/_identifier.py | 30 +
src/strands/agent/agent.py | 6 +-
src/strands/session/file_session_manager.py | 27 +-
src/strands/session/s3_session_manager.py | 23 +-
tests/strands/agent/test_agent.py | 12 +
.../session/test_file_session_manager.py | 604 +++++++++---------
.../session/test_s3_session_manager.py | 24 +
tests/strands/test_identifier.py | 17 +
tests/strands/tools/test_decorator.py | 9 +-
9 files changed, 452 insertions(+), 300 deletions(-)
create mode 100644 src/strands/_identifier.py
create mode 100644 tests/strands/test_identifier.py
diff --git a/src/strands/_identifier.py b/src/strands/_identifier.py
new file mode 100644
index 000000000..e8b12635c
--- /dev/null
+++ b/src/strands/_identifier.py
@@ -0,0 +1,30 @@
+"""Strands identifier utilities."""
+
+import enum
+import os
+
+
+class Identifier(enum.Enum):
+ """Strands identifier types."""
+
+ AGENT = "agent"
+ SESSION = "session"
+
+
+def validate(id_: str, type_: Identifier) -> str:
+ """Validate strands id.
+
+ Args:
+ id_: Id to validate.
+ type_: Type of the identifier (e.g., session id, agent id, etc.)
+
+ Returns:
+ Validated id.
+
+ Raises:
+ ValueError: If id contains path separators.
+ """
+ if os.path.basename(id_) != id_:
+ raise ValueError(f"{type_.value}_id={id_} | id cannot contain path separators")
+
+ return id_
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 43b5cbf8c..38e687af2 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -19,6 +19,7 @@
from opentelemetry import trace as trace_api
from pydantic import BaseModel
+from .. import _identifier
from ..event_loop.event_loop import event_loop_cycle, run_tool
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from ..hooks import (
@@ -249,12 +250,15 @@ def __init__(
Defaults to None.
session_manager: Manager for handling agent sessions including conversation history and state.
If provided, enables session-based persistence and state management.
+
+ Raises:
+ ValueError: If agent id contains path separators.
"""
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
self.messages = messages if messages is not None else []
self.system_prompt = system_prompt
- self.agent_id = agent_id or _DEFAULT_AGENT_ID
+ self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
self.name = name or _DEFAULT_AGENT_NAME
self.description = description
diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py
index fec2f0761..9df86e17a 100644
--- a/src/strands/session/file_session_manager.py
+++ b/src/strands/session/file_session_manager.py
@@ -7,6 +7,7 @@
import tempfile
from typing import Any, Optional, cast
+from .. import _identifier
from ..types.exceptions import SessionException
from ..types.session import Session, SessionAgent, SessionMessage
from .repository_session_manager import RepositorySessionManager
@@ -40,8 +41,9 @@ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs:
"""Initialize FileSession with filesystem storage.
Args:
- session_id: ID for the session
- storage_dir: Directory for local filesystem storage (defaults to temp dir)
+ session_id: ID for the session.
+ ID is not allowed to contain path separators (e.g., a/b).
+ storage_dir: Directory for local filesystem storage (defaults to temp dir).
**kwargs: Additional keyword arguments for future extensibility.
"""
self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions")
@@ -50,12 +52,29 @@ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs:
super().__init__(session_id=session_id, session_repository=self)
def _get_session_path(self, session_id: str) -> str:
- """Get session directory path."""
+ """Get session directory path.
+
+ Args:
+ session_id: ID for the session.
+
+ Raises:
+ ValueError: If session id contains a path separator.
+ """
+ session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION)
return os.path.join(self.storage_dir, f"{SESSION_PREFIX}{session_id}")
def _get_agent_path(self, session_id: str, agent_id: str) -> str:
- """Get agent directory path."""
+ """Get agent directory path.
+
+ Args:
+ session_id: ID for the session.
+ agent_id: ID for the agent.
+
+ Raises:
+ ValueError: If session id or agent id contains a path separator.
+ """
session_path = self._get_session_path(session_id)
+ agent_id = _identifier.validate(agent_id, _identifier.Identifier.AGENT)
return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}")
def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:
diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py
index 0cc0a68c1..d15e6e3bd 100644
--- a/src/strands/session/s3_session_manager.py
+++ b/src/strands/session/s3_session_manager.py
@@ -8,6 +8,7 @@
from botocore.config import Config as BotocoreConfig
from botocore.exceptions import ClientError
+from .. import _identifier
from ..types.exceptions import SessionException
from ..types.session import Session, SessionAgent, SessionMessage
from .repository_session_manager import RepositorySessionManager
@@ -51,6 +52,7 @@ def __init__(
Args:
session_id: ID for the session
+ ID is not allowed to contain path separators (e.g., a/b).
bucket: S3 bucket name (required)
prefix: S3 key prefix for storage organization
boto_session: Optional boto3 session
@@ -79,12 +81,29 @@ def __init__(
super().__init__(session_id=session_id, session_repository=self)
def _get_session_path(self, session_id: str) -> str:
- """Get session S3 prefix."""
+ """Get session S3 prefix.
+
+ Args:
+ session_id: ID for the session.
+
+ Raises:
+ ValueError: If session id contains a path separator.
+ """
+ session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION)
return f"{self.prefix}/{SESSION_PREFIX}{session_id}/"
def _get_agent_path(self, session_id: str, agent_id: str) -> str:
- """Get agent S3 prefix."""
+ """Get agent S3 prefix.
+
+ Args:
+ session_id: ID for the session.
+ agent_id: ID for the agent.
+
+ Raises:
+ ValueError: If session id or agent id contains a path separator.
+ """
session_path = self._get_session_path(session_id)
+ agent_id = _identifier.validate(agent_id, _identifier.Identifier.AGENT)
return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/"
def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index fdce7c368..ca66ca2bf 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -250,6 +250,18 @@ def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_impo
assert tru_tool_names == exp_tool_names
+@pytest.mark.parametrize(
+ "agent_id",
+ [
+ "a/../b",
+ "a/b",
+ ],
+)
+def test_agent__init__invalid_id(agent_id):
+ with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"):
+ Agent(agent_id=agent_id)
+
+
def test_agent__call__(
mock_model,
system_prompt,
diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py
index f9fc3ba94..a89222b7e 100644
--- a/tests/strands/session/test_file_session_manager.py
+++ b/tests/strands/session/test_file_session_manager.py
@@ -53,310 +53,340 @@ def sample_message():
)
-class TestFileSessionManagerSessionOperations:
- """Tests for session operations."""
-
- def test_create_session(self, file_manager, sample_session):
- """Test creating a session."""
- file_manager.create_session(sample_session)
-
- # Verify directory structure created
- session_path = file_manager._get_session_path(sample_session.session_id)
- assert os.path.exists(session_path)
-
- # Verify session file created
- session_file = os.path.join(session_path, "session.json")
- assert os.path.exists(session_file)
-
- # Verify content
- with open(session_file, "r") as f:
- data = json.load(f)
- assert data["session_id"] == sample_session.session_id
- assert data["session_type"] == sample_session.session_type
-
- def test_read_session(self, file_manager, sample_session):
- """Test reading an existing session."""
- # Create session first
- file_manager.create_session(sample_session)
-
- # Read it back
- result = file_manager.read_session(sample_session.session_id)
-
- assert result.session_id == sample_session.session_id
- assert result.session_type == sample_session.session_type
-
- def test_read_nonexistent_session(self, file_manager):
- """Test reading a session that doesn't exist."""
- result = file_manager.read_session("nonexistent-session")
- assert result is None
-
- def test_delete_session(self, file_manager, sample_session):
- """Test deleting a session."""
- # Create session first
- file_manager.create_session(sample_session)
- session_path = file_manager._get_session_path(sample_session.session_id)
- assert os.path.exists(session_path)
-
- # Delete session
- file_manager.delete_session(sample_session.session_id)
-
- # Verify deletion
- assert not os.path.exists(session_path)
-
- def test_delete_nonexistent_session(self, file_manager):
- """Test deleting a session that doesn't exist."""
- # Should raise an error according to the implementation
- with pytest.raises(SessionException, match="does not exist"):
- file_manager.delete_session("nonexistent-session")
-
-
-class TestFileSessionManagerAgentOperations:
- """Tests for agent operations."""
-
- def test_create_agent(self, file_manager, sample_session, sample_agent):
- """Test creating an agent in a session."""
- # Create session first
- file_manager.create_session(sample_session)
-
- # Create agent
- file_manager.create_agent(sample_session.session_id, sample_agent)
-
- # Verify directory structure
- agent_path = file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id)
- assert os.path.exists(agent_path)
-
- # Verify agent file
- agent_file = os.path.join(agent_path, "agent.json")
- assert os.path.exists(agent_file)
-
- # Verify content
- with open(agent_file, "r") as f:
- data = json.load(f)
- assert data["agent_id"] == sample_agent.agent_id
- assert data["state"] == sample_agent.state
-
- def test_read_agent(self, file_manager, sample_session, sample_agent):
- """Test reading an agent from a session."""
- # Create session and agent
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
-
- # Read agent
- result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id)
-
- assert result.agent_id == sample_agent.agent_id
- assert result.state == sample_agent.state
-
- def test_read_nonexistent_agent(self, file_manager, sample_session):
- """Test reading an agent that doesn't exist."""
- result = file_manager.read_agent(sample_session.session_id, "nonexistent_agent")
- assert result is None
-
- def test_update_agent(self, file_manager, sample_session, sample_agent):
- """Test updating an agent."""
- # Create session and agent
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
-
- # Update agent
- sample_agent.state = {"updated": "value"}
+def test_create_session(file_manager, sample_session):
+ """Test creating a session."""
+ file_manager.create_session(sample_session)
+
+ # Verify directory structure created
+ session_path = file_manager._get_session_path(sample_session.session_id)
+ assert os.path.exists(session_path)
+
+ # Verify session file created
+ session_file = os.path.join(session_path, "session.json")
+ assert os.path.exists(session_file)
+
+ # Verify content
+ with open(session_file, "r") as f:
+ data = json.load(f)
+ assert data["session_id"] == sample_session.session_id
+ assert data["session_type"] == sample_session.session_type
+
+
+def test_read_session(file_manager, sample_session):
+ """Test reading an existing session."""
+ # Create session first
+ file_manager.create_session(sample_session)
+
+ # Read it back
+ result = file_manager.read_session(sample_session.session_id)
+
+ assert result.session_id == sample_session.session_id
+ assert result.session_type == sample_session.session_type
+
+
+def test_read_nonexistent_session(file_manager):
+ """Test reading a session that doesn't exist."""
+ result = file_manager.read_session("nonexistent-session")
+ assert result is None
+
+
+def test_delete_session(file_manager, sample_session):
+ """Test deleting a session."""
+ # Create session first
+ file_manager.create_session(sample_session)
+ session_path = file_manager._get_session_path(sample_session.session_id)
+ assert os.path.exists(session_path)
+
+ # Delete session
+ file_manager.delete_session(sample_session.session_id)
+
+ # Verify deletion
+ assert not os.path.exists(session_path)
+
+
+def test_delete_nonexistent_session(file_manager):
+ """Test deleting a session that doesn't exist."""
+ # Should raise an error according to the implementation
+ with pytest.raises(SessionException, match="does not exist"):
+ file_manager.delete_session("nonexistent-session")
+
+
+def test_create_agent(file_manager, sample_session, sample_agent):
+ """Test creating an agent in a session."""
+ # Create session first
+ file_manager.create_session(sample_session)
+
+ # Create agent
+ file_manager.create_agent(sample_session.session_id, sample_agent)
+
+ # Verify directory structure
+ agent_path = file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id)
+ assert os.path.exists(agent_path)
+
+ # Verify agent file
+ agent_file = os.path.join(agent_path, "agent.json")
+ assert os.path.exists(agent_file)
+
+ # Verify content
+ with open(agent_file, "r") as f:
+ data = json.load(f)
+ assert data["agent_id"] == sample_agent.agent_id
+ assert data["state"] == sample_agent.state
+
+
+def test_read_agent(file_manager, sample_session, sample_agent):
+ """Test reading an agent from a session."""
+ # Create session and agent
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
+
+ # Read agent
+ result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id)
+
+ assert result.agent_id == sample_agent.agent_id
+ assert result.state == sample_agent.state
+
+
+def test_read_nonexistent_agent(file_manager, sample_session):
+ """Test reading an agent that doesn't exist."""
+ result = file_manager.read_agent(sample_session.session_id, "nonexistent_agent")
+ assert result is None
+
+
+def test_update_agent(file_manager, sample_session, sample_agent):
+ """Test updating an agent."""
+ # Create session and agent
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
+
+ # Update agent
+ sample_agent.state = {"updated": "value"}
+ file_manager.update_agent(sample_session.session_id, sample_agent)
+
+ # Verify update
+ result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id)
+ assert result.state == {"updated": "value"}
+
+
+def test_update_nonexistent_agent(file_manager, sample_session, sample_agent):
+    """Test that updating a nonexistent agent raises SessionException."""
+    # Create the session only; the agent is intentionally never created
+    file_manager.create_session(sample_session)
+
+    # Updating the nonexistent agent should raise
+ with pytest.raises(SessionException):
file_manager.update_agent(sample_session.session_id, sample_agent)
- # Verify update
- result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id)
- assert result.state == {"updated": "value"}
- def test_update_nonexistent_agent(self, file_manager, sample_session, sample_agent):
- """Test updating an agent."""
- # Create session and agent
- file_manager.create_session(sample_session)
+def test_create_message(file_manager, sample_session, sample_agent, sample_message):
+ """Test creating a message for an agent."""
+ # Create session and agent
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
- # Update agent
- with pytest.raises(SessionException):
- file_manager.update_agent(sample_session.session_id, sample_agent)
+ # Create message
+ file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message)
+
+ # Verify message file
+ message_path = file_manager._get_message_path(
+ sample_session.session_id, sample_agent.agent_id, sample_message.message_id
+ )
+ assert os.path.exists(message_path)
+ # Verify content
+ with open(message_path, "r") as f:
+ data = json.load(f)
+ assert data["message_id"] == sample_message.message_id
-class TestFileSessionManagerMessageOperations:
- """Tests for message operations."""
- def test_create_message(self, file_manager, sample_session, sample_agent, sample_message):
- """Test creating a message for an agent."""
- # Create session and agent
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
+def test_read_message(file_manager, sample_session, sample_agent, sample_message):
+ """Test reading a message."""
+ # Create session, agent, and message
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
+ file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message)
- # Create message
- file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message)
+    # Create a second message so the read targets a specific message id
+ sample_message.message_id = sample_message.message_id + 1
+ file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message)
- # Verify message file
- message_path = file_manager._get_message_path(
- sample_session.session_id, sample_agent.agent_id, sample_message.message_id
+ # Read message
+ result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id)
+
+ assert result.message_id == sample_message.message_id
+ assert result.message["role"] == sample_message.message["role"]
+ assert result.message["content"] == sample_message.message["content"]
+
+
+def test_read_messages_with_new_agent(file_manager, sample_session, sample_agent):
+    """Test reading a message with a new agent."""
+ # Create session and agent
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
+
+ result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message")
+
+ assert result is None
+
+
+def test_read_nonexistent_message(file_manager, sample_session, sample_agent):
+    """Test reading a message that doesn't exist."""
+ result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message")
+ assert result is None
+
+
+def test_list_messages_all(file_manager, sample_session, sample_agent):
+ """Test listing all messages for an agent."""
+ # Create session and agent
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
+
+ # Create multiple messages
+ messages = []
+ for i in range(5):
+ message = SessionMessage(
+ message={
+ "role": "user",
+ "content": [ContentBlock(text=f"Message {i}")],
+ },
+ message_id=i,
)
- assert os.path.exists(message_path)
-
- # Verify content
- with open(message_path, "r") as f:
- data = json.load(f)
- assert data["message_id"] == sample_message.message_id
-
- def test_read_message(self, file_manager, sample_session, sample_agent, sample_message):
- """Test reading a message."""
- # Create session, agent, and message
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
- file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message)
-
- # Create multiple messages when reading
- sample_message.message_id = sample_message.message_id + 1
- file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message)
-
- # Read message
- result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id)
-
- assert result.message_id == sample_message.message_id
- assert result.message["role"] == sample_message.message["role"]
- assert result.message["content"] == sample_message.message["content"]
-
- def test_read_messages_with_new_agent(self, file_manager, sample_session, sample_agent):
- """Test reading a message with with a new agent."""
- # Create session and agent
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
-
- result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message")
-
- assert result is None
-
- def test_read_nonexistent_message(self, file_manager, sample_session, sample_agent):
- """Test reading a message that doesnt exist."""
- result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message")
- assert result is None
-
- def test_list_messages_all(self, file_manager, sample_session, sample_agent):
- """Test listing all messages for an agent."""
- # Create session and agent
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
-
- # Create multiple messages
- messages = []
- for i in range(5):
- message = SessionMessage(
- message={
- "role": "user",
- "content": [ContentBlock(text=f"Message {i}")],
- },
- message_id=i,
- )
- messages.append(message)
- file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message)
-
- # List all messages
- result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id)
-
- assert len(result) == 5
-
- def test_list_messages_with_limit(self, file_manager, sample_session, sample_agent):
- """Test listing messages with limit."""
- # Create session and agent
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
-
- # Create multiple messages
- for i in range(10):
- message = SessionMessage(
- message={
- "role": "user",
- "content": [ContentBlock(text=f"Message {i}")],
- },
- message_id=i,
- )
- file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message)
-
- # List with limit
- result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3)
-
- assert len(result) == 3
-
- def test_list_messages_with_offset(self, file_manager, sample_session, sample_agent):
- """Test listing messages with offset."""
- # Create session and agent
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
-
- # Create multiple messages
- for i in range(10):
- message = SessionMessage(
- message={
- "role": "user",
- "content": [ContentBlock(text=f"Message {i}")],
- },
- message_id=i,
- )
- file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message)
-
- # List with offset
- result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5)
-
- assert len(result) == 5
-
- def test_list_messages_with_new_agent(self, file_manager, sample_session, sample_agent):
- """Test listing messages with new agent."""
- # Create session and agent
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
-
- result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id)
-
- assert len(result) == 0
-
- def test_update_message(self, file_manager, sample_session, sample_agent, sample_message):
- """Test updating a message."""
- # Create session, agent, and message
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
- file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message)
-
- # Update message
- sample_message.message["content"] = [ContentBlock(text="Updated content")]
- file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message)
+ messages.append(message)
+ file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message)
- # Verify update
- result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id)
- assert result.message["content"][0]["text"] == "Updated content"
+ # List all messages
+ result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id)
- def test_update_nonexistent_message(self, file_manager, sample_session, sample_agent, sample_message):
- """Test updating a message."""
- # Create session, agent, and message
- file_manager.create_session(sample_session)
- file_manager.create_agent(sample_session.session_id, sample_agent)
+ assert len(result) == 5
+
+
+def test_list_messages_with_limit(file_manager, sample_session, sample_agent):
+ """Test listing messages with limit."""
+ # Create session and agent
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
+
+ # Create multiple messages
+ for i in range(10):
+ message = SessionMessage(
+ message={
+ "role": "user",
+ "content": [ContentBlock(text=f"Message {i}")],
+ },
+ message_id=i,
+ )
+ file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message)
+
+ # List with limit
+ result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3)
+
+ assert len(result) == 3
- # Update nonexistent message
- with pytest.raises(SessionException):
- file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message)
+def test_list_messages_with_offset(file_manager, sample_session, sample_agent):
+ """Test listing messages with offset."""
+ # Create session and agent
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
-class TestFileSessionManagerErrorHandling:
- """Tests for error handling scenarios."""
+ # Create multiple messages
+ for i in range(10):
+ message = SessionMessage(
+ message={
+ "role": "user",
+ "content": [ContentBlock(text=f"Message {i}")],
+ },
+ message_id=i,
+ )
+ file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message)
+
+ # List with offset
+ result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5)
+
+ assert len(result) == 5
+
+
+def test_list_messages_with_new_agent(file_manager, sample_session, sample_agent):
+ """Test listing messages with new agent."""
+ # Create session and agent
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
+
+ result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id)
+
+ assert len(result) == 0
+
+
+def test_update_message(file_manager, sample_session, sample_agent, sample_message):
+ """Test updating a message."""
+ # Create session, agent, and message
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
+ file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message)
- def test_corrupted_json_file(self, file_manager, temp_dir):
- """Test handling of corrupted JSON files."""
- # Create a corrupted session file
- session_path = os.path.join(temp_dir, "session_test")
- os.makedirs(session_path, exist_ok=True)
- session_file = os.path.join(session_path, "session.json")
+ # Update message
+ sample_message.message["content"] = [ContentBlock(text="Updated content")]
+ file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message)
- with open(session_file, "w") as f:
- f.write("invalid json content")
+ # Verify update
+ result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id)
+ assert result.message["content"][0]["text"] == "Updated content"
- # Should raise SessionException
- with pytest.raises(SessionException, match="Invalid JSON"):
- file_manager._read_file(session_file)
- def test_permission_error_handling(self, file_manager):
- """Test handling of permission errors."""
- with patch("builtins.open", side_effect=PermissionError("Access denied")):
- session = Session(session_id="test", session_type=SessionType.AGENT)
+def test_update_nonexistent_message(file_manager, sample_session, sample_agent, sample_message):
+    """Test that updating a nonexistent message raises SessionException."""
+    # Create the session and agent, but intentionally not the message
+ file_manager.create_session(sample_session)
+ file_manager.create_agent(sample_session.session_id, sample_agent)
- with pytest.raises(SessionException):
- file_manager.create_session(session)
+ # Update nonexistent message
+ with pytest.raises(SessionException):
+ file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message)
+
+
+def test_corrupted_json_file(file_manager, temp_dir):
+ """Test handling of corrupted JSON files."""
+ # Create a corrupted session file
+ session_path = os.path.join(temp_dir, "session_test")
+ os.makedirs(session_path, exist_ok=True)
+ session_file = os.path.join(session_path, "session.json")
+
+ with open(session_file, "w") as f:
+ f.write("invalid json content")
+
+ # Should raise SessionException
+ with pytest.raises(SessionException, match="Invalid JSON"):
+ file_manager._read_file(session_file)
+
+
+def test_permission_error_handling(file_manager):
+ """Test handling of permission errors."""
+ with patch("builtins.open", side_effect=PermissionError("Access denied")):
+ session = Session(session_id="test", session_type=SessionType.AGENT)
+
+ with pytest.raises(SessionException):
+ file_manager.create_session(session)
+
+
+@pytest.mark.parametrize(
+ "session_id",
+ [
+ "a/../b",
+ "a/b",
+ ],
+)
+def test__get_session_path_invalid_session_id(session_id, file_manager):
+ with pytest.raises(ValueError, match=f"session_id={session_id} | id cannot contain path separators"):
+ file_manager._get_session_path(session_id)
+
+
+@pytest.mark.parametrize(
+ "agent_id",
+ [
+ "a/../b",
+ "a/b",
+ ],
+)
+def test__get_agent_path_invalid_agent_id(agent_id, file_manager):
+ with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"):
+ file_manager._get_agent_path("session1", agent_id)
diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py
index fadd0db4b..71bff3050 100644
--- a/tests/strands/session/test_s3_session_manager.py
+++ b/tests/strands/session/test_s3_session_manager.py
@@ -332,3 +332,27 @@ def test_update_nonexistent_message(s3_manager, sample_session, sample_agent, sa
# Update message
with pytest.raises(SessionException):
s3_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message)
+
+
+@pytest.mark.parametrize(
+ "session_id",
+ [
+ "a/../b",
+ "a/b",
+ ],
+)
+def test__get_session_path_invalid_session_id(session_id, s3_manager):
+ with pytest.raises(ValueError, match=f"session_id={session_id} | id cannot contain path separators"):
+ s3_manager._get_session_path(session_id)
+
+
+@pytest.mark.parametrize(
+ "agent_id",
+ [
+ "a/../b",
+ "a/b",
+ ],
+)
+def test__get_agent_path_invalid_agent_id(agent_id, s3_manager):
+ with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"):
+ s3_manager._get_agent_path("session1", agent_id)
diff --git a/tests/strands/test_identifier.py b/tests/strands/test_identifier.py
new file mode 100644
index 000000000..df673baa8
--- /dev/null
+++ b/tests/strands/test_identifier.py
@@ -0,0 +1,17 @@
+import pytest
+
+from strands import _identifier
+
+
+@pytest.mark.parametrize("type_", list(_identifier.Identifier))
+def test_validate(type_):
+ tru_id = _identifier.validate("abc", type_)
+ exp_id = "abc"
+ assert tru_id == exp_id
+
+
+@pytest.mark.parametrize("type_", list(_identifier.Identifier))
+def test_validate_invalid(type_):
+ id_ = "a/../b"
+    with pytest.raises(ValueError, match=f"{type_.value}_id={id_} | id cannot contain path separators"):
+ _identifier.validate(id_, type_)
diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py
index 246879da7..e490c7bb0 100644
--- a/tests/strands/tools/test_decorator.py
+++ b/tests/strands/tools/test_decorator.py
@@ -1064,7 +1064,7 @@ async def _run_context_injection_test(context_tool: AgentTool):
"content": [
{"text": "Tool 'context_tool' (ID: test-id)"},
{"text": "injected agent 'test_agent' processed: some_message"},
- {"text": "context agent 'test_agent'"}
+ {"text": "context agent 'test_agent'"},
],
"toolUseId": "test-id",
}
@@ -1151,7 +1151,7 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> dict:
assert len(tool_results) == 1
tool_result = tool_results[0]
-
+
# Should get a validation error because tool_context is required but not provided
assert tool_result["status"] == "error"
assert "tool_context" in tool_result["content"][0]["text"].lower()
@@ -1173,10 +1173,7 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str:
tool_use={
"toolUseId": "test-id-2",
"name": "context_tool",
- "input": {
- "message": "some_message",
- "tool_context": "my_custom_context_string"
- },
+ "input": {"message": "some_message", "tool_context": "my_custom_context_string"},
},
invocation_state={
"agent": Agent(name="test_agent"),
From fbd598a0abea3d1b5a9781f7cdb5819ed81f51ca Mon Sep 17 00:00:00 2001
From: Clare Liguori
Date: Mon, 18 Aug 2025 06:21:15 -0700
Subject: [PATCH 035/104] fix: only set signature in message if signature was
provided by the model (#682)
---
src/strands/event_loop/streaming.py | 19 ++++++++++---------
tests/strands/event_loop/test_streaming.py | 15 +++++++++++++++
2 files changed, 25 insertions(+), 9 deletions(-)
diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py
index 74cadaf9e..f4048a65c 100644
--- a/src/strands/event_loop/streaming.py
+++ b/src/strands/event_loop/streaming.py
@@ -194,16 +194,18 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
state["text"] = ""
elif reasoning_text:
- content.append(
- {
- "reasoningContent": {
- "reasoningText": {
- "text": state["reasoningText"],
- "signature": state["signature"],
- }
+ content_block: ContentBlock = {
+ "reasoningContent": {
+ "reasoningText": {
+ "text": state["reasoningText"],
}
}
- )
+ }
+
+ if "signature" in state:
+ content_block["reasoningContent"]["reasoningText"]["signature"] = state["signature"]
+
+ content.append(content_block)
state["reasoningText"] = ""
return state
@@ -263,7 +265,6 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
"text": "",
"current_tool_use": {},
"reasoningText": "",
- "signature": "",
}
state["content"] = state["message"]["content"]
diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py
index 921fd91de..b1cc312c2 100644
--- a/tests/strands/event_loop/test_streaming.py
+++ b/tests/strands/event_loop/test_streaming.py
@@ -216,6 +216,21 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up
"signature": "123",
},
),
+ # Reasoning without signature
+ (
+ {
+ "content": [],
+ "current_tool_use": {},
+ "text": "",
+ "reasoningText": "test",
+ },
+ {
+ "content": [{"reasoningContent": {"reasoningText": {"text": "test"}}}],
+ "current_tool_use": {},
+ "text": "",
+ "reasoningText": "",
+ },
+ ),
# Empty
(
{
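Pulled out of the stream handler, the new shape logic reduces to the sketch below (plain dicts here; the SDK uses its ContentBlock typed dicts):

```python
def reasoning_block(state: dict) -> dict:
    """Build a reasoningContent block, attaching `signature` only if the model sent one."""
    block = {"reasoningContent": {"reasoningText": {"text": state["reasoningText"]}}}
    if "signature" in state:
        block["reasoningContent"]["reasoningText"]["signature"] = state["signature"]
    return block


print(reasoning_block({"reasoningText": "test"}))
# {'reasoningContent': {'reasoningText': {'text': 'test'}}}
print(reasoning_block({"reasoningText": "test", "signature": "123"}))
# {'reasoningContent': {'reasoningText': {'text': 'test', 'signature': '123'}}}
```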
From ae74aa33ed9502c72f7d0f46757ec1c5a91fcb00 Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Mon, 18 Aug 2025 10:03:33 -0400
Subject: [PATCH 036/104] fix: Add openai dependency to sagemaker dependency
group (#678)
It depends on OpenAI and we got a report about the need to install it explicitly.
Co-authored-by: Mackenzie Zastrow
---
pyproject.toml | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 487b26691..6c0b6e3f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,7 +92,9 @@ writer = [
sagemaker = [
"boto3>=1.26.0,<2.0.0",
"botocore>=1.29.0,<2.0.0",
- "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0"
+ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0",
+ # uses OpenAI as part of the implementation
+ "openai>=1.68.0,<2.0.0",
]
a2a = [
From 980a988f4cc3b580d37359f3646d2b603715ad69 Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Mon, 18 Aug 2025 15:10:58 -0400
Subject: [PATCH 037/104] Have [all] group reference the other optional
dependency groups by name (#674)
---
pyproject.toml | 49 ++++---------------------------------------------
1 file changed, 4 insertions(+), 45 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 6c0b6e3f7..847db8d2b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,8 @@ docs = [
]
litellm = [
"litellm>=1.73.1,<2.0.0",
+ # https://github.com/BerriAI/litellm/issues/13711
+ "openai<1.100.0",
]
llamaapi = [
"llama-api-client>=0.1.0,<1.0.0",
@@ -106,50 +108,7 @@ a2a = [
"starlette>=0.46.2,<1.0.0",
]
all = [
- # anthropic
- "anthropic>=0.21.0,<1.0.0",
-
- # dev
- "commitizen>=4.4.0,<5.0.0",
- "hatch>=1.0.0,<2.0.0",
- "moto>=5.1.0,<6.0.0",
- "mypy>=1.15.0,<2.0.0",
- "pre-commit>=3.2.0,<4.2.0",
- "pytest>=8.0.0,<9.0.0",
- "pytest-asyncio>=0.26.0,<0.27.0",
- "pytest-cov>=4.1.0,<5.0.0",
- "pytest-xdist>=3.0.0,<4.0.0",
- "ruff>=0.4.4,<0.5.0",
-
- # docs
- "sphinx>=5.0.0,<6.0.0",
- "sphinx-rtd-theme>=1.0.0,<2.0.0",
- "sphinx-autodoc-typehints>=1.12.0,<2.0.0",
-
- # litellm
- "litellm>=1.72.6,<1.73.0",
-
- # llama
- "llama-api-client>=0.1.0,<1.0.0",
-
- # mistral
- "mistralai>=1.8.2",
-
- # ollama
- "ollama>=0.4.8,<1.0.0",
-
- # openai
- "openai>=1.68.0,<2.0.0",
-
- # otel
- "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
-
- # a2a
- "a2a-sdk[sql]>=0.3.0,<0.4.0",
- "uvicorn>=0.34.2,<1.0.0",
- "httpx>=0.28.1,<1.0.0",
- "fastapi>=0.115.12,<1.0.0",
- "starlette>=0.46.2,<1.0.0",
+ "strands-agents[a2a,anthropic,dev,docs,litellm,llamaapi,mistral,ollama,openai,otel]",
]
[tool.hatch.version]
@@ -161,7 +120,7 @@ features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mis
dependencies = [
"mypy>=1.15.0,<2.0.0",
"ruff>=0.11.6,<0.12.0",
- "strands-agents @ {root:uri}"
+ "strands-agents @ {root:uri}",
]
[tool.hatch.envs.hatch-static-analysis.scripts]
From b1df148fbc89bb057348a897ea42fa3c6501ac63 Mon Sep 17 00:00:00 2001
From: poshinchen
Date: Mon, 18 Aug 2025 16:06:28 -0400
Subject: [PATCH 038/104] fix: append blank text content if assistant content
is empty (#677)
---
src/strands/event_loop/streaming.py | 6 +++---
tests/strands/event_loop/test_streaming.py | 2 ++
2 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py
index f4048a65c..1f8c260a4 100644
--- a/src/strands/event_loop/streaming.py
+++ b/src/strands/event_loop/streaming.py
@@ -40,10 +40,12 @@ def remove_blank_messages_content_text(messages: Messages) -> Messages:
# only modify assistant messages
if "role" in message and message["role"] != "assistant":
continue
-
if "content" in message:
content = message["content"]
has_tool_use = any("toolUse" in item for item in content)
+ if len(content) == 0:
+ content.append({"text": "[blank text]"})
+ continue
if has_tool_use:
# Remove blank 'text' items for assistant messages
@@ -273,7 +275,6 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
async for chunk in chunks:
yield {"callback": {"event": chunk}}
-
if "messageStart" in chunk:
state["message"] = handle_message_start(chunk["messageStart"], state["message"])
elif "contentBlockStart" in chunk:
@@ -313,7 +314,6 @@ async def stream_messages(
logger.debug("model=<%s> | streaming messages", model)
messages = remove_blank_messages_content_text(messages)
-
chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt)
async for event in process_stream(chunks):
diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py
index b1cc312c2..66deb282c 100644
--- a/tests/strands/event_loop/test_streaming.py
+++ b/tests/strands/event_loop/test_streaming.py
@@ -26,6 +26,7 @@ def moto_autouse(moto_env, moto_mock_aws):
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}, {"toolUse": {}}]},
{"role": "assistant", "content": [{"text": ""}, {"toolUse": {}}]},
{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]},
+ {"role": "assistant", "content": []},
{"role": "assistant"},
{"role": "user", "content": [{"text": " \n"}]},
],
@@ -33,6 +34,7 @@ def moto_autouse(moto_env, moto_mock_aws):
{"role": "assistant", "content": [{"text": "a"}, {"toolUse": {}}]},
{"role": "assistant", "content": [{"toolUse": {}}]},
{"role": "assistant", "content": [{"text": "a"}, {"text": "[blank text]"}]},
+ {"role": "assistant", "content": [{"text": "[blank text]"}]},
{"role": "assistant"},
{"role": "user", "content": [{"text": " \n"}]},
],
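Reduced to plain dicts, the normalization this patch adds behaves like the sketch below (the surrounding function also strips blank text items from assistant messages that carry tool uses):

```python
def fill_blank_assistant_content(messages: list[dict]) -> list[dict]:
    """Give empty assistant content a placeholder block so providers accept the message."""
    for message in messages:
        if message.get("role") == "assistant" and message.get("content") == []:
            message["content"].append({"text": "[blank text]"})
    return messages


print(fill_blank_assistant_content([{"role": "assistant", "content": []}]))
# [{'role': 'assistant', 'content': [{'text': '[blank text]'}]}]
```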
From cfcf93dc781cc0300f3faae19530c527bbe595ad Mon Sep 17 00:00:00 2001
From: Oz Altagar
Date: Tue, 19 Aug 2025 00:06:08 +0300
Subject: [PATCH 039/104] feat: add cached token metrics support for Amazon
Bedrock (#531)
* feat: add cached token metrics support for Amazon Bedrock
- Add optional cacheReadInputTokens and cacheWriteInputTokens fields to Usage TypedDict
- Update EventLoopMetrics to accumulate cached token metrics
- Add OpenTelemetry instrumentation for cached token telemetry
- Enhance metrics summary display to show cached token information
- Maintain 100% backward compatibility with existing Usage objects
- Add comprehensive test coverage for cached token functionality
Resolves #529
* feat: updated cached read/write input token metrics
---------
Co-authored-by: poshinchen
---
src/strands/telemetry/metrics.py | 45 +++++++++++++++++++---
src/strands/telemetry/metrics_constants.py | 2 +
src/strands/types/event_loop.py | 16 +++++---
tests/strands/event_loop/test_streaming.py | 12 ++++++
tests/strands/telemetry/test_metrics.py | 8 ++--
5 files changed, 66 insertions(+), 17 deletions(-)
diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py
index 332ab2ae3..883273f64 100644
--- a/src/strands/telemetry/metrics.py
+++ b/src/strands/telemetry/metrics.py
@@ -11,7 +11,7 @@
from ..telemetry import metrics_constants as constants
from ..types.content import Message
-from ..types.streaming import Metrics, Usage
+from ..types.event_loop import Metrics, Usage
from ..types.tools import ToolUse
logger = logging.getLogger(__name__)
@@ -264,6 +264,21 @@ def update_usage(self, usage: Usage) -> None:
self.accumulated_usage["outputTokens"] += usage["outputTokens"]
self.accumulated_usage["totalTokens"] += usage["totalTokens"]
+ # Handle optional cached token metrics
+ if "cacheReadInputTokens" in usage:
+ cache_read_tokens = usage["cacheReadInputTokens"]
+ self._metrics_client.event_loop_cache_read_input_tokens.record(cache_read_tokens)
+ self.accumulated_usage["cacheReadInputTokens"] = (
+ self.accumulated_usage.get("cacheReadInputTokens", 0) + cache_read_tokens
+ )
+
+ if "cacheWriteInputTokens" in usage:
+ cache_write_tokens = usage["cacheWriteInputTokens"]
+ self._metrics_client.event_loop_cache_write_input_tokens.record(cache_write_tokens)
+ self.accumulated_usage["cacheWriteInputTokens"] = (
+ self.accumulated_usage.get("cacheWriteInputTokens", 0) + cache_write_tokens
+ )
+
def update_metrics(self, metrics: Metrics) -> None:
"""Update the accumulated performance metrics with new metrics data.
@@ -325,11 +340,21 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name
f"āā Cycles: total={summary['total_cycles']}, avg_time={summary['average_cycle_time']:.3f}s, "
f"total_time={summary['total_duration']:.3f}s"
)
- yield (
-        f"├─ Tokens: in={summary['accumulated_usage']['inputTokens']}, "
- f"out={summary['accumulated_usage']['outputTokens']}, "
- f"total={summary['accumulated_usage']['totalTokens']}"
- )
+
+ # Build token display with optional cached tokens
+ token_parts = [
+ f"in={summary['accumulated_usage']['inputTokens']}",
+ f"out={summary['accumulated_usage']['outputTokens']}",
+ f"total={summary['accumulated_usage']['totalTokens']}",
+ ]
+
+ # Add cached token info if present
+ if summary["accumulated_usage"].get("cacheReadInputTokens"):
+ token_parts.append(f"cache_read_input_tokens={summary['accumulated_usage']['cacheReadInputTokens']}")
+ if summary["accumulated_usage"].get("cacheWriteInputTokens"):
+ token_parts.append(f"cache_write_input_tokens={summary['accumulated_usage']['cacheWriteInputTokens']}")
+
+    yield f"├─ Tokens: {', '.join(token_parts)}"
     yield f"├─ Bedrock Latency: {summary['accumulated_metrics']['latencyMs']}ms"
     yield "├─ Tool Usage:"
@@ -421,6 +446,8 @@ class MetricsClient:
event_loop_latency: Histogram
event_loop_input_tokens: Histogram
event_loop_output_tokens: Histogram
+ event_loop_cache_read_input_tokens: Histogram
+ event_loop_cache_write_input_tokens: Histogram
tool_call_count: Counter
tool_success_count: Counter
@@ -474,3 +501,9 @@ def create_instruments(self) -> None:
self.event_loop_output_tokens = self.meter.create_histogram(
name=constants.STRANDS_EVENT_LOOP_OUTPUT_TOKENS, unit="token"
)
+ self.event_loop_cache_read_input_tokens = self.meter.create_histogram(
+ name=constants.STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS, unit="token"
+ )
+ self.event_loop_cache_write_input_tokens = self.meter.create_histogram(
+ name=constants.STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS, unit="token"
+ )
diff --git a/src/strands/telemetry/metrics_constants.py b/src/strands/telemetry/metrics_constants.py
index b622eebff..f8fac34da 100644
--- a/src/strands/telemetry/metrics_constants.py
+++ b/src/strands/telemetry/metrics_constants.py
@@ -13,3 +13,5 @@
STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration"
STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens"
STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens"
+STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS = "strands.event_loop.cache_read.input.tokens"
+STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS = "strands.event_loop.cache_write.input.tokens"
diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py
index 7be33b6fd..2c240972b 100644
--- a/src/strands/types/event_loop.py
+++ b/src/strands/types/event_loop.py
@@ -2,21 +2,25 @@
from typing import Literal
-from typing_extensions import TypedDict
+from typing_extensions import Required, TypedDict
-class Usage(TypedDict):
+class Usage(TypedDict, total=False):
"""Token usage information for model interactions.
Attributes:
- inputTokens: Number of tokens sent in the request to the model..
+ inputTokens: Number of tokens sent in the request to the model.
outputTokens: Number of tokens that the model generated for the request.
totalTokens: Total number of tokens (input + output).
+ cacheReadInputTokens: Number of tokens read from cache (optional).
+ cacheWriteInputTokens: Number of tokens written to cache (optional).
"""
- inputTokens: int
- outputTokens: int
- totalTokens: int
+ inputTokens: Required[int]
+ outputTokens: Required[int]
+ totalTokens: Required[int]
+ cacheReadInputTokens: int
+ cacheWriteInputTokens: int
class Metrics(TypedDict):
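Marking the class total=False while wrapping the original keys in Required keeps existing call sites type-safe and lets the cache counters be omitted entirely. A minimal self-contained sketch of the resulting contract (values illustrative):

    from typing_extensions import Required, TypedDict

    class Usage(TypedDict, total=False):
        inputTokens: Required[int]
        outputTokens: Required[int]
        totalTokens: Required[int]
        cacheReadInputTokens: int
        cacheWriteInputTokens: int

    # Both are valid Usage values; the cache keys may simply be absent.
    legacy: Usage = {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}
    cached: Usage = {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15, "cacheReadInputTokens": 8}

    # Readers use .get(), as the summary code above does.
    print(cached.get("cacheReadInputTokens"), legacy.get("cacheReadInputTokens"))  # 8 None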
diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py
index 66deb282c..7760c498a 100644
--- a/tests/strands/event_loop/test_streaming.py
+++ b/tests/strands/event_loop/test_streaming.py
@@ -277,6 +277,18 @@ def test_extract_usage_metrics():
assert tru_usage == exp_usage and tru_metrics == exp_metrics
+def test_extract_usage_metrics_with_cache_tokens():
+ event = {
+ "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0, "cacheReadInputTokens": 0},
+ "metrics": {"latencyMs": 0},
+ }
+
+ tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event)
+ exp_usage, exp_metrics = event["usage"], event["metrics"]
+
+ assert tru_usage == exp_usage and tru_metrics == exp_metrics
+
+
@pytest.mark.parametrize(
("response", "exp_events"),
[
diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py
index 215e1efde..12db81908 100644
--- a/tests/strands/telemetry/test_metrics.py
+++ b/tests/strands/telemetry/test_metrics.py
@@ -90,6 +90,7 @@ def usage(request):
"inputTokens": 1,
"outputTokens": 2,
"totalTokens": 3,
+ "cacheWriteInputTokens": 2,
}
if hasattr(request, "param"):
params.update(request.param)
@@ -315,17 +316,14 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_met
event_loop_metrics.update_usage(usage)
tru_usage = event_loop_metrics.accumulated_usage
- exp_usage = Usage(
- inputTokens=3,
- outputTokens=6,
- totalTokens=9,
- )
+ exp_usage = Usage(inputTokens=3, outputTokens=6, totalTokens=9, cacheWriteInputTokens=6)
assert tru_usage == exp_usage
mock_get_meter_provider.return_value.get_meter.assert_called()
metrics_client = event_loop_metrics._metrics_client
metrics_client.event_loop_input_tokens.record.assert_called()
metrics_client.event_loop_output_tokens.record.assert_called()
+ metrics_client.event_loop_cache_write_input_tokens.record.assert_called()
def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider):
From c087f1883dcad7481de2499cb2d2d891c19e4ee7 Mon Sep 17 00:00:00 2001
From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com>
Date: Tue, 19 Aug 2025 23:06:45 +0800
Subject: [PATCH 040/104] fix: fix non-serializable parameter of agent from
toolUse block (#568)
* fix: fix non-serializable parameter of agent from toolUse block
* feat: Add configuration option to MCP Client for server init timeout (#657)
Co-authored-by: Harry Wilton
* fix: Bedrock hang when exception occurs during message conversion (#643)
Previously (#642), Bedrock would hang during message conversion because the exception was not being caught, leaving the queue permanently empty. Now all exceptions raised during conversion are caught.
Co-authored-by: Mackenzie Zastrow
* fix: only include parameters that are defined in the tool spec
---------
Co-authored-by: Jack Yuan
Co-authored-by: fhwilton55 <81768750+fhwilton55@users.noreply.github.com>
Co-authored-by: Harry Wilton
Co-authored-by: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Co-authored-by: Mackenzie Zastrow
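The core of the fix is the dict comprehension in _filter_tool_parameters_for_recording: only keys named in the tool spec's input schema survive into message history. A minimal sketch of that behavior, using a hypothetical one-parameter spec:

    # Hypothetical spec declaring a single "action" parameter.
    tool_spec = {"inputSchema": {"json": {"properties": {"action": {"type": "string"}}}}}
    input_params = {"action": "test_value", "agent": object(), "extra_param": "filtered"}

    properties = tool_spec["inputSchema"]["json"]["properties"]
    filtered = {k: v for k, v in input_params.items() if k in properties}
    print(filtered)  # {'action': 'test_value'}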
---
src/strands/agent/agent.py | 33 +++++++-
tests/strands/agent/test_agent.py | 127 ++++++++----------------------
2 files changed, 65 insertions(+), 95 deletions(-)
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 38e687af2..acc6a7650 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -642,8 +642,11 @@ def _record_tool_execution(
tool_result: The result returned by the tool.
user_message_override: Optional custom message to include.
"""
+ # Filter tool input parameters to only include those defined in tool spec
+ filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"])
+
# Create user message describing the tool call
- input_parameters = json.dumps(tool["input"], default=lambda o: f"<>")
+ input_parameters = json.dumps(filtered_input, default=lambda o: f"<>")
user_msg_content: list[ContentBlock] = [
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")}
@@ -653,6 +656,13 @@ def _record_tool_execution(
if user_message_override:
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})
+ # Create filtered tool use for message history
+ filtered_tool: ToolUse = {
+ "toolUseId": tool["toolUseId"],
+ "name": tool["name"],
+ "input": filtered_input,
+ }
+
# Create the message sequence
user_msg: Message = {
"role": "user",
@@ -660,7 +670,7 @@ def _record_tool_execution(
}
tool_use_msg: Message = {
"role": "assistant",
- "content": [{"toolUse": tool}],
+ "content": [{"toolUse": filtered_tool}],
}
tool_result_msg: Message = {
"role": "user",
@@ -717,6 +727,25 @@ def _end_agent_trace_span(
self.tracer.end_agent_span(**trace_attributes)
+ def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]:
+ """Filter input parameters to only include those defined in the tool specification.
+
+ Args:
+ tool_name: Name of the tool to get specification for
+ input_params: Original input parameters
+
+ Returns:
+ Filtered parameters containing only those defined in tool spec
+ """
+ all_tools_config = self.tool_registry.get_all_tools_config()
+ tool_spec = all_tools_config.get(tool_name)
+
+ if not tool_spec or "inputSchema" not in tool_spec:
+ return input_params.copy()
+
+ properties = tool_spec["inputSchema"]["json"]["properties"]
+ return {k: v for k, v in input_params.items() if k in properties}
+
def _append_message(self, message: Message) -> None:
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
self.messages.append(message)
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index ca66ca2bf..444232455 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -1738,99 +1738,7 @@ def test_agent_tool_non_serializable_parameter_filtering(agent, mock_randint):
tool_call_text = user_message["content"][1]["text"]
assert "agent.tool.tool_decorated direct tool call." in tool_call_text
assert '"random_string": "test_value"' in tool_call_text
- assert '"non_serializable_agent": "<>"' in tool_call_text
-
-
-def test_agent_tool_multiple_non_serializable_types(agent, mock_randint):
- """Test filtering of various non-serializable object types."""
- mock_randint.return_value = 123
-
- # Create various non-serializable objects
- class CustomClass:
- def __init__(self, value):
- self.value = value
-
- non_serializable_objects = {
- "agent": Agent(),
- "custom_object": CustomClass("test"),
- "function": lambda x: x,
- "set_object": {1, 2, 3},
- "complex_number": 3 + 4j,
- "serializable_string": "this_should_remain",
- "serializable_number": 42,
- "serializable_list": [1, 2, 3],
- "serializable_dict": {"key": "value"},
- }
-
- # This should not crash
- result = agent.tool.tool_decorated(random_string="test_filtering", **non_serializable_objects)
-
- # Verify tool executed successfully
- expected_result = {
- "content": [{"text": "test_filtering"}],
- "status": "success",
- "toolUseId": "tooluse_tool_decorated_123",
- }
- assert result == expected_result
-
- # Check the recorded message for proper parameter filtering
- assert len(agent.messages) > 0
- user_message = agent.messages[0]
- tool_call_text = user_message["content"][0]["text"]
-
- # Verify serializable objects remain unchanged
- assert '"serializable_string": "this_should_remain"' in tool_call_text
- assert '"serializable_number": 42' in tool_call_text
- assert '"serializable_list": [1, 2, 3]' in tool_call_text
- assert '"serializable_dict": {"key": "value"}' in tool_call_text
-
- # Verify non-serializable objects are replaced with descriptive strings
- assert '"agent": "<>"' in tool_call_text
- assert (
- '"custom_object": "<.CustomClass>>"'
- in tool_call_text
- )
- assert '"function": "<>"' in tool_call_text
- assert '"set_object": "<>"' in tool_call_text
- assert '"complex_number": "<>"' in tool_call_text
-
-
-def test_agent_tool_serialization_edge_cases(agent, mock_randint):
- """Test edge cases in parameter serialization filtering."""
- mock_randint.return_value = 999
-
- # Test with None values, empty containers, and nested structures
- edge_case_params = {
- "none_value": None,
- "empty_list": [],
- "empty_dict": {},
- "nested_list_with_non_serializable": [1, 2, Agent()], # This should be filtered out
- "nested_dict_serializable": {"nested": {"key": "value"}}, # This should remain
- }
-
- result = agent.tool.tool_decorated(random_string="edge_cases", **edge_case_params)
-
- # Verify successful execution
- expected_result = {
- "content": [{"text": "edge_cases"}],
- "status": "success",
- "toolUseId": "tooluse_tool_decorated_999",
- }
- assert result == expected_result
-
- # Check parameter filtering in recorded message
- assert len(agent.messages) > 0
- user_message = agent.messages[0]
- tool_call_text = user_message["content"][0]["text"]
-
- # Verify serializable values remain
- assert '"none_value": null' in tool_call_text
- assert '"empty_list": []' in tool_call_text
- assert '"empty_dict": {}' in tool_call_text
- assert '"nested_dict_serializable": {"nested": {"key": "value"}}' in tool_call_text
-
- # Verify non-serializable nested structure is replaced
- assert '"nested_list_with_non_serializable": [1, 2, "<>"]' in tool_call_text
+ assert '"non_serializable_agent": "<>"' not in tool_call_text
def test_agent_tool_no_non_serializable_parameters(agent, mock_randint):
@@ -1882,3 +1790,36 @@ def test_agent_tool_record_direct_tool_call_disabled_with_non_serializable(agent
# Verify no messages were recorded
assert len(agent.messages) == 0
+
+
+def test_agent_tool_call_parameter_filtering_integration(mock_randint):
+ """Test that tool calls properly filter parameters in message recording."""
+ mock_randint.return_value = 42
+
+ @strands.tool
+ def test_tool(action: str) -> str:
+ """Test tool with single parameter."""
+ return action
+
+ agent = Agent(tools=[test_tool])
+
+ # Call tool with extra non-spec parameters
+ result = agent.tool.test_tool(
+ action="test_value",
+ agent=agent, # Should be filtered out
+ extra_param="filtered", # Should be filtered out
+ )
+
+ # Verify tool executed successfully
+ assert result["status"] == "success"
+ assert result["content"] == [{"text": "test_value"}]
+
+ # Check that only spec parameters are recorded in message history
+ assert len(agent.messages) > 0
+ user_message = agent.messages[0]
+ tool_call_text = user_message["content"][0]["text"]
+
+ # Should only contain the 'action' parameter
+ assert '"action": "test_value"' in tool_call_text
+ assert '"agent"' not in tool_call_text
+ assert '"extra_param"' not in tool_call_text
From 17ccdd2df3ee7a213fe59a24d51a5ea238879117 Mon Sep 17 00:00:00 2001
From: vawsgit <147627358+vawsgit@users.noreply.github.com>
Date: Tue, 19 Aug 2025 10:15:51 -0500
Subject: [PATCH 041/104] chore: add .DS_Store to .gitignore (#681)
---
.gitignore | 1 +
1 file changed, 1 insertion(+)
diff --git a/.gitignore b/.gitignore
index c27d1d902..888a96bbc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+.DS_Store
build
__pycache__*
.coverage*
From ef18a255d5949b9ebbd46f08575ea881b5c64106 Mon Sep 17 00:00:00 2001
From: Jeremiah
Date: Wed, 20 Aug 2025 13:21:02 -0400
Subject: [PATCH 042/104] feat(a2a): support A2A FileParts and DataParts (#596)
Co-authored-by: jer
---
src/strands/multiagent/a2a/executor.py | 185 +++-
tests/strands/multiagent/a2a/test_executor.py | 787 +++++++++++++++++-
2 files changed, 947 insertions(+), 25 deletions(-)
diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py
index 5bf9cbfe9..74ecc6531 100644
--- a/src/strands/multiagent/a2a/executor.py
+++ b/src/strands/multiagent/a2a/executor.py
@@ -8,18 +8,29 @@
streamed requests to the A2AServer.
"""
+import json
import logging
-from typing import Any
+import mimetypes
+from typing import Any, Literal
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
-from a2a.types import InternalError, Part, TaskState, TextPart, UnsupportedOperationError
+from a2a.types import DataPart, FilePart, InternalError, Part, TaskState, TextPart, UnsupportedOperationError
from a2a.utils import new_agent_text_message, new_task
from a2a.utils.errors import ServerError
from ...agent.agent import Agent as SAAgent
from ...agent.agent import AgentResult as SAAgentResult
+from ...types.content import ContentBlock
+from ...types.media import (
+ DocumentContent,
+ DocumentSource,
+ ImageContent,
+ ImageSource,
+ VideoContent,
+ VideoSource,
+)
logger = logging.getLogger(__name__)
@@ -31,6 +42,12 @@ class StrandsA2AExecutor(AgentExecutor):
and converts Strands Agent responses to A2A protocol events.
"""
+ # Default formats for each file type when MIME type is unavailable or unrecognized
+ DEFAULT_FORMATS = {"document": "txt", "image": "png", "video": "mp4", "unknown": "txt"}
+
+ # Handle special cases where format differs from extension
+ FORMAT_MAPPINGS = {"jpg": "jpeg", "htm": "html", "3gp": "three_gp", "3gpp": "three_gp", "3g2": "three_gp"}
+
def __init__(self, agent: SAAgent):
"""Initialize a StrandsA2AExecutor.
@@ -78,10 +95,16 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
context: The A2A request context, containing the user's input and other metadata.
updater: The task updater for managing task state and sending updates.
"""
- logger.info("Executing request in streaming mode")
- user_input = context.get_user_input()
+ # Convert A2A message parts to Strands ContentBlocks
+ if context.message and hasattr(context.message, "parts"):
+ content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts)
+ if not content_blocks:
+ raise ValueError("No content blocks available")
+ else:
+ raise ValueError("No content blocks available")
+
try:
- async for event in self.agent.stream_async(user_input):
+ async for event in self.agent.stream_async(content_blocks):
await self._handle_streaming_event(event, updater)
except Exception:
logger.exception("Error in streaming execution")
@@ -146,3 +169,155 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None
"""
logger.warning("Cancellation requested but not supported")
raise ServerError(error=UnsupportedOperationError())
+
+ def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]:
+ """Classify file type based on MIME type.
+
+ Args:
+ mime_type: The MIME type of the file
+
+ Returns:
+ The classified file type
+ """
+ if not mime_type:
+ return "unknown"
+
+ mime_type = mime_type.lower()
+
+ if mime_type.startswith("image/"):
+ return "image"
+ elif mime_type.startswith("video/"):
+ return "video"
+ elif (
+ mime_type.startswith("text/")
+ or mime_type.startswith("application/")
+ or mime_type in ["application/pdf", "application/json", "application/xml"]
+ ):
+ return "document"
+ else:
+ return "unknown"
+
+ def _get_file_format_from_mime_type(self, mime_type: str | None, file_type: str) -> str:
+ """Extract file format from MIME type using Python's mimetypes library.
+
+ Args:
+ mime_type: The MIME type of the file
file_type: The classified file type (image, video, document, or unknown)
+
+ Returns:
+ The file format string
+ """
+ if not mime_type:
+ return self.DEFAULT_FORMATS.get(file_type, "txt")
+
+ mime_type = mime_type.lower()
+
+ # Extract subtype from MIME type and check existing format mappings
+ if "/" in mime_type:
+ subtype = mime_type.split("/")[-1]
+ if subtype in self.FORMAT_MAPPINGS:
+ return self.FORMAT_MAPPINGS[subtype]
+
+ # Use mimetypes library to find extensions for the MIME type
+ extensions = mimetypes.guess_all_extensions(mime_type)
+
+ if extensions:
+ extension = extensions[0][1:] # Remove the leading dot
+ return self.FORMAT_MAPPINGS.get(extension, extension)
+
+ # Fallback to defaults for unknown MIME types
+ return self.DEFAULT_FORMATS.get(file_type, "txt")
+
+ def _strip_file_extension(self, file_name: str) -> str:
+ """Strip the file extension from a file name.
+
+ Args:
+ file_name: The original file name with extension
+
+ Returns:
+ The file name without extension
+ """
+ if "." in file_name:
+ return file_name.rsplit(".", 1)[0]
+ return file_name
+
+ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[ContentBlock]:
+ """Convert A2A message parts to Strands ContentBlocks.
+
+ Args:
+ parts: List of A2A Part objects
+
+ Returns:
+ List of Strands ContentBlock objects
+ """
+ content_blocks: list[ContentBlock] = []
+
+ for part in parts:
+ try:
+ part_root = part.root
+
+ if isinstance(part_root, TextPart):
+ # Handle TextPart
+ content_blocks.append(ContentBlock(text=part_root.text))
+
+ elif isinstance(part_root, FilePart):
+ # Handle FilePart
+ file_obj = part_root.file
+ mime_type = getattr(file_obj, "mime_type", None)
+ raw_file_name = getattr(file_obj, "name", "FileNameNotProvided")
+ file_name = self._strip_file_extension(raw_file_name)
+ file_type = self._get_file_type_from_mime_type(mime_type)
+ file_format = self._get_file_format_from_mime_type(mime_type, file_type)
+
+ # Handle FileWithBytes vs FileWithUri
+ bytes_data = getattr(file_obj, "bytes", None)
+ uri_data = getattr(file_obj, "uri", None)
+
+ if bytes_data:
+ if file_type == "image":
+ content_blocks.append(
+ ContentBlock(
+ image=ImageContent(
+ format=file_format, # type: ignore
+ source=ImageSource(bytes=bytes_data),
+ )
+ )
+ )
+ elif file_type == "video":
+ content_blocks.append(
+ ContentBlock(
+ video=VideoContent(
+ format=file_format, # type: ignore
+ source=VideoSource(bytes=bytes_data),
+ )
+ )
+ )
+ else: # document or unknown
+ content_blocks.append(
+ ContentBlock(
+ document=DocumentContent(
+ format=file_format, # type: ignore
+ name=file_name,
+ source=DocumentSource(bytes=bytes_data),
+ )
+ )
+ )
+ # Handle FileWithUri
+ elif uri_data:
+ # For URI files, create a text representation since Strands ContentBlocks expect bytes
+ content_blocks.append(
+ ContentBlock(
+ text="[File: %s (%s)] - Referenced file at: %s" % (file_name, mime_type, uri_data)
+ )
+ )
+ elif isinstance(part_root, DataPart):
+ # Handle DataPart - convert structured data to JSON text
+ try:
+ data_text = json.dumps(part_root.data, indent=2)
+ content_blocks.append(ContentBlock(text="[Structured Data]\n%s" % data_text))
+ except Exception:
+ logger.exception("Failed to serialize data part")
+ except Exception:
+ logger.exception("Error processing part")
+
+ return content_blocks
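The format lookup tries the MIME subtype against FORMAT_MAPPINGS before consulting the mimetypes registry, because file extensions and model-facing format names diverge ("jpg" vs "jpeg", "3gpp" vs "three_gp"). A minimal sketch of the two paths; the exact extension list mimetypes returns can vary by platform:

    import mimetypes

    FORMAT_MAPPINGS = {"jpg": "jpeg", "htm": "html", "3gp": "three_gp", "3gpp": "three_gp", "3g2": "three_gp"}

    # Path 1: the subtype itself is a known special case.
    subtype = "video/3gpp".split("/")[-1]
    print(FORMAT_MAPPINGS[subtype])  # three_gp

    # Path 2: unrecognized subtype, so fall back to the mimetypes registry.
    extensions = mimetypes.guess_all_extensions("application/pdf")
    if extensions:
        extension = extensions[0][1:]  # strip the leading dot
        print(FORMAT_MAPPINGS.get(extension, extension))  # pdf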
diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py
index 77645fc73..3f63119f2 100644
--- a/tests/strands/multiagent/a2a/test_executor.py
+++ b/tests/strands/multiagent/a2a/test_executor.py
@@ -3,11 +3,12 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
-from a2a.types import UnsupportedOperationError
+from a2a.types import InternalError, UnsupportedOperationError
from a2a.utils.errors import ServerError
from strands.agent.agent_result import AgentResult as SAAgentResult
from strands.multiagent.a2a.executor import StrandsA2AExecutor
+from strands.types.content import ContentBlock
def test_executor_initialization(mock_strands_agent):
@@ -17,18 +18,304 @@ def test_executor_initialization(mock_strands_agent):
assert executor.agent == mock_strands_agent
+def test_classify_file_type():
+ """Test file type classification based on MIME type."""
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Test image types
+ assert executor._get_file_type_from_mime_type("image/jpeg") == "image"
+ assert executor._get_file_type_from_mime_type("image/png") == "image"
+
+ # Test video types
+ assert executor._get_file_type_from_mime_type("video/mp4") == "video"
+ assert executor._get_file_type_from_mime_type("video/mpeg") == "video"
+
+ # Test document types
+ assert executor._get_file_type_from_mime_type("text/plain") == "document"
+ assert executor._get_file_type_from_mime_type("application/pdf") == "document"
+ assert executor._get_file_type_from_mime_type("application/json") == "document"
+
+ # Test unknown/edge cases
+ assert executor._get_file_type_from_mime_type("audio/mp3") == "unknown"
+ assert executor._get_file_type_from_mime_type(None) == "unknown"
+ assert executor._get_file_type_from_mime_type("") == "unknown"
+
+
+def test_get_file_format_from_mime_type():
+ """Test file format extraction from MIME type using mimetypes library."""
+ executor = StrandsA2AExecutor(MagicMock())
+ assert executor._get_file_format_from_mime_type("image/jpeg", "image") == "jpeg"
+ assert executor._get_file_format_from_mime_type("image/png", "image") == "png"
+ assert executor._get_file_format_from_mime_type("image/unknown", "image") == "png"
+
+ # Test video formats
+ assert executor._get_file_format_from_mime_type("video/mp4", "video") == "mp4"
+ assert executor._get_file_format_from_mime_type("video/3gpp", "video") == "three_gp"
+ assert executor._get_file_format_from_mime_type("video/unknown", "video") == "mp4"
+
+ # Test document formats
+ assert executor._get_file_format_from_mime_type("application/pdf", "document") == "pdf"
+ assert executor._get_file_format_from_mime_type("text/plain", "document") == "txt"
+ assert executor._get_file_format_from_mime_type("application/unknown", "document") == "txt"
+
+ # Test None/empty cases
+ assert executor._get_file_format_from_mime_type(None, "image") == "png"
+ assert executor._get_file_format_from_mime_type("", "video") == "mp4"
+
+
+def test_strip_file_extension():
+ """Test file extension stripping."""
+ executor = StrandsA2AExecutor(MagicMock())
+
+ assert executor._strip_file_extension("test.txt") == "test"
+ assert executor._strip_file_extension("document.pdf") == "document"
+ assert executor._strip_file_extension("image.jpeg") == "image"
+ assert executor._strip_file_extension("no_extension") == "no_extension"
+ assert executor._strip_file_extension("multiple.dots.file.ext") == "multiple.dots.file"
+
+
+def test_convert_a2a_parts_to_content_blocks_text_part():
+ """Test conversion of TextPart to ContentBlock."""
+ from a2a.types import TextPart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Mock TextPart with proper spec
+ text_part = MagicMock(spec=TextPart)
+ text_part.text = "Hello, world!"
+
+ # Mock Part with TextPart root
+ part = MagicMock()
+ part.root = text_part
+
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ assert len(result) == 1
+ assert result[0] == ContentBlock(text="Hello, world!")
+
+
+def test_convert_a2a_parts_to_content_blocks_file_part_image_bytes():
+ """Test conversion of FilePart with image bytes to ContentBlock."""
+ from a2a.types import FilePart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Create test image bytes (no base64 encoding needed)
+ test_bytes = b"fake_image_data"
+
+ # Mock file object
+ file_obj = MagicMock()
+ file_obj.name = "test_image.jpeg"
+ file_obj.mime_type = "image/jpeg"
+ file_obj.bytes = test_bytes
+ file_obj.uri = None
+
+ # Mock FilePart with proper spec
+ file_part = MagicMock(spec=FilePart)
+ file_part.file = file_obj
+
+ # Mock Part with FilePart root
+ part = MagicMock()
+ part.root = file_part
+
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ assert len(result) == 1
+ content_block = result[0]
+ assert "image" in content_block
+ assert content_block["image"]["format"] == "jpeg"
+ assert content_block["image"]["source"]["bytes"] == test_bytes
+
+
+def test_convert_a2a_parts_to_content_blocks_file_part_video_bytes():
+ """Test conversion of FilePart with video bytes to ContentBlock."""
+ from a2a.types import FilePart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Create test video bytes (no base64 encoding needed)
+ test_bytes = b"fake_video_data"
+
+ # Mock file object
+ file_obj = MagicMock()
+ file_obj.name = "test_video.mp4"
+ file_obj.mime_type = "video/mp4"
+ file_obj.bytes = test_bytes
+ file_obj.uri = None
+
+ # Mock FilePart with proper spec
+ file_part = MagicMock(spec=FilePart)
+ file_part.file = file_obj
+
+ # Mock Part with FilePart root
+ part = MagicMock()
+ part.root = file_part
+
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ assert len(result) == 1
+ content_block = result[0]
+ assert "video" in content_block
+ assert content_block["video"]["format"] == "mp4"
+ assert content_block["video"]["source"]["bytes"] == test_bytes
+
+
+def test_convert_a2a_parts_to_content_blocks_file_part_document_bytes():
+ """Test conversion of FilePart with document bytes to ContentBlock."""
+ from a2a.types import FilePart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Create test document bytes (no base64 encoding needed)
+ test_bytes = b"fake_document_data"
+
+ # Mock file object
+ file_obj = MagicMock()
+ file_obj.name = "test_document.pdf"
+ file_obj.mime_type = "application/pdf"
+ file_obj.bytes = test_bytes
+ file_obj.uri = None
+
+ # Mock FilePart with proper spec
+ file_part = MagicMock(spec=FilePart)
+ file_part.file = file_obj
+
+ # Mock Part with FilePart root
+ part = MagicMock()
+ part.root = file_part
+
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ assert len(result) == 1
+ content_block = result[0]
+ assert "document" in content_block
+ assert content_block["document"]["format"] == "pdf"
+ assert content_block["document"]["name"] == "test_document"
+ assert content_block["document"]["source"]["bytes"] == test_bytes
+
+
+def test_convert_a2a_parts_to_content_blocks_file_part_uri():
+ """Test conversion of FilePart with URI to ContentBlock."""
+ from a2a.types import FilePart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Mock file object with URI
+ file_obj = MagicMock()
+ file_obj.name = "test_image.png"
+ file_obj.mime_type = "image/png"
+ file_obj.bytes = None
+ file_obj.uri = "https://example.com/image.png"
+
+ # Mock FilePart with proper spec
+ file_part = MagicMock(spec=FilePart)
+ file_part.file = file_obj
+
+ # Mock Part with FilePart root
+ part = MagicMock()
+ part.root = file_part
+
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ assert len(result) == 1
+ content_block = result[0]
+ assert "text" in content_block
+ assert "test_image" in content_block["text"]
+ assert "https://example.com/image.png" in content_block["text"]
+
+
+def test_convert_a2a_parts_to_content_blocks_file_part_with_bytes():
+ """Test conversion of FilePart with bytes data."""
+ from a2a.types import FilePart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Mock file object with bytes (no validation needed since no decoding)
+ file_obj = MagicMock()
+ file_obj.name = "test_image.png"
+ file_obj.mime_type = "image/png"
+ file_obj.bytes = b"some_binary_data"
+ file_obj.uri = None
+
+ # Mock FilePart with proper spec
+ file_part = MagicMock(spec=FilePart)
+ file_part.file = file_obj
+
+ # Mock Part with FilePart root
+ part = MagicMock()
+ part.root = file_part
+
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ assert len(result) == 1
+ content_block = result[0]
+ assert "image" in content_block
+ assert content_block["image"]["source"]["bytes"] == b"some_binary_data"
+
+
+def test_convert_a2a_parts_to_content_blocks_data_part():
+ """Test conversion of DataPart to ContentBlock."""
+ from a2a.types import DataPart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Mock DataPart with proper spec
+ test_data = {"key": "value", "number": 42}
+ data_part = MagicMock(spec=DataPart)
+ data_part.data = test_data
+
+ # Mock Part with DataPart root
+ part = MagicMock()
+ part.root = data_part
+
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ assert len(result) == 1
+ content_block = result[0]
+ assert "text" in content_block
+ assert "[Structured Data]" in content_block["text"]
+ assert "key" in content_block["text"]
+ assert "value" in content_block["text"]
+
+
+def test_convert_a2a_parts_to_content_blocks_mixed_parts():
+ """Test conversion of mixed A2A parts to ContentBlocks."""
+ from a2a.types import DataPart, TextPart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Mock TextPart with proper spec
+ text_part = MagicMock(spec=TextPart)
+ text_part.text = "Text content"
+ text_part_mock = MagicMock()
+ text_part_mock.root = text_part
+
+ # Mock DataPart with proper spec
+ data_part = MagicMock(spec=DataPart)
+ data_part.data = {"test": "data"}
+ data_part_mock = MagicMock()
+ data_part_mock.root = data_part
+
+ parts = [text_part_mock, data_part_mock]
+ result = executor._convert_a2a_parts_to_content_blocks(parts)
+
+ assert len(result) == 2
+ assert result[0]["text"] == "Text content"
+ assert "[Structured Data]" in result[1]["text"]
+
+
@pytest.mark.asyncio
async def test_execute_streaming_mode_with_data_events(mock_strands_agent, mock_request_context, mock_event_queue):
"""Test that execute processes data events correctly in streaming mode."""
- async def mock_stream(user_input):
+ async def mock_stream(content_blocks):
"""Mock streaming function that yields data events."""
yield {"data": "First chunk"}
yield {"data": "Second chunk"}
yield {"result": MagicMock(spec=SAAgentResult)}
# Setup mock agent streaming
- mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input"))
+ mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([]))
# Create executor
executor = StrandsA2AExecutor(mock_strands_agent)
@@ -39,10 +326,25 @@ async def mock_stream(user_input):
mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
+ # Mock message with parts
+ from a2a.types import TextPart
+
+ mock_message = MagicMock()
+ text_part = MagicMock(spec=TextPart)
+ text_part.text = "Test input"
+ part = MagicMock()
+ part.root = text_part
+ mock_message.parts = [part]
+ mock_request_context.message = mock_message
+
await executor.execute(mock_request_context, mock_event_queue)
- # Verify agent was called with correct input
- mock_strands_agent.stream_async.assert_called_once_with("Test input")
+ # Verify agent was called with ContentBlock list
+ mock_strands_agent.stream_async.assert_called_once()
+ call_args = mock_strands_agent.stream_async.call_args[0][0]
+ assert isinstance(call_args, list)
+ assert len(call_args) == 1
+ assert call_args[0]["text"] == "Test input"
# Verify events were enqueued
mock_event_queue.enqueue_event.assert_called()
@@ -52,12 +354,12 @@ async def mock_stream(user_input):
async def test_execute_streaming_mode_with_result_event(mock_strands_agent, mock_request_context, mock_event_queue):
"""Test that execute processes result events correctly in streaming mode."""
- async def mock_stream(user_input):
+ async def mock_stream(content_blocks):
"""Mock streaming function that yields only result event."""
yield {"result": MagicMock(spec=SAAgentResult)}
# Setup mock agent streaming
- mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input"))
+ mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([]))
# Create executor
executor = StrandsA2AExecutor(mock_strands_agent)
@@ -68,10 +370,25 @@ async def mock_stream(user_input):
mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
+ # Mock message with parts
+ from a2a.types import TextPart
+
+ mock_message = MagicMock()
+ text_part = MagicMock(spec=TextPart)
+ text_part.text = "Test input"
+ part = MagicMock()
+ part.root = text_part
+ mock_message.parts = [part]
+ mock_request_context.message = mock_message
+
await executor.execute(mock_request_context, mock_event_queue)
- # Verify agent was called with correct input
- mock_strands_agent.stream_async.assert_called_once_with("Test input")
+ # Verify agent was called with ContentBlock list
+ mock_strands_agent.stream_async.assert_called_once()
+ call_args = mock_strands_agent.stream_async.call_args[0][0]
+ assert isinstance(call_args, list)
+ assert len(call_args) == 1
+ assert call_args[0]["text"] == "Test input"
# Verify events were enqueued
mock_event_queue.enqueue_event.assert_called()
@@ -81,13 +398,13 @@ async def mock_stream(user_input):
async def test_execute_streaming_mode_with_empty_data(mock_strands_agent, mock_request_context, mock_event_queue):
"""Test that execute handles empty data events correctly in streaming mode."""
- async def mock_stream(user_input):
+ async def mock_stream(content_blocks):
"""Mock streaming function that yields empty data."""
yield {"data": ""}
yield {"result": MagicMock(spec=SAAgentResult)}
# Setup mock agent streaming
- mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input"))
+ mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([]))
# Create executor
executor = StrandsA2AExecutor(mock_strands_agent)
@@ -98,10 +415,25 @@ async def mock_stream(user_input):
mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
+ # Mock message with parts
+ from a2a.types import TextPart
+
+ mock_message = MagicMock()
+ text_part = MagicMock(spec=TextPart)
+ text_part.text = "Test input"
+ part = MagicMock()
+ part.root = text_part
+ mock_message.parts = [part]
+ mock_request_context.message = mock_message
+
await executor.execute(mock_request_context, mock_event_queue)
- # Verify agent was called with correct input
- mock_strands_agent.stream_async.assert_called_once_with("Test input")
+ # Verify agent was called with ContentBlock list
+ mock_strands_agent.stream_async.assert_called_once()
+ call_args = mock_strands_agent.stream_async.call_args[0][0]
+ assert isinstance(call_args, list)
+ assert len(call_args) == 1
+ assert call_args[0]["text"] == "Test input"
# Verify events were enqueued
mock_event_queue.enqueue_event.assert_called()
@@ -111,13 +443,13 @@ async def mock_stream(user_input):
async def test_execute_streaming_mode_with_unexpected_event(mock_strands_agent, mock_request_context, mock_event_queue):
"""Test that execute handles unexpected events correctly in streaming mode."""
- async def mock_stream(user_input):
+ async def mock_stream(content_blocks):
"""Mock streaming function that yields unexpected event."""
yield {"unexpected": "event"}
yield {"result": MagicMock(spec=SAAgentResult)}
# Setup mock agent streaming
- mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input"))
+ mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([]))
# Create executor
executor = StrandsA2AExecutor(mock_strands_agent)
@@ -128,26 +460,69 @@ async def mock_stream(user_input):
mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
+ # Mock message with parts
+ from a2a.types import TextPart
+
+ mock_message = MagicMock()
+ text_part = MagicMock(spec=TextPart)
+ text_part.text = "Test input"
+ part = MagicMock()
+ part.root = text_part
+ mock_message.parts = [part]
+ mock_request_context.message = mock_message
+
await executor.execute(mock_request_context, mock_event_queue)
- # Verify agent was called with correct input
- mock_strands_agent.stream_async.assert_called_once_with("Test input")
+ # Verify agent was called with ContentBlock list
+ mock_strands_agent.stream_async.assert_called_once()
+ call_args = mock_strands_agent.stream_async.call_args[0][0]
+ assert isinstance(call_args, list)
+ assert len(call_args) == 1
+ assert call_args[0]["text"] == "Test input"
# Verify events were enqueued
mock_event_queue.enqueue_event.assert_called()
+@pytest.mark.asyncio
+async def test_execute_streaming_mode_fallback_to_text_extraction(
+ mock_strands_agent, mock_request_context, mock_event_queue
+):
+ """Test that execute raises ServerError when no A2A parts are available."""
+
+ # Create executor
+ executor = StrandsA2AExecutor(mock_strands_agent)
+
+ # Mock the task creation
+ mock_task = MagicMock()
+ mock_task.id = "test-task-id"
+ mock_task.context_id = "test-context-id"
+ mock_request_context.current_task = mock_task
+
+ # Mock message without parts attribute
+ mock_message = MagicMock()
+ delattr(mock_message, "parts") # Remove parts attribute
+ mock_request_context.message = mock_message
+ mock_request_context.get_user_input.return_value = "Fallback input"
+
+ with pytest.raises(ServerError) as excinfo:
+ await executor.execute(mock_request_context, mock_event_queue)
+
+ # Verify the error is a ServerError containing an InternalError
+ assert isinstance(excinfo.value.error, InternalError)
+
+
@pytest.mark.asyncio
async def test_execute_creates_task_when_none_exists(mock_strands_agent, mock_request_context, mock_event_queue):
"""Test that execute creates a new task when none exists."""
- async def mock_stream(user_input):
+ async def mock_stream(content_blocks):
"""Mock streaming function that yields data events."""
yield {"data": "Test chunk"}
yield {"result": MagicMock(spec=SAAgentResult)}
# Setup mock agent streaming
- mock_strands_agent.stream_async = MagicMock(return_value=mock_stream("Test input"))
+ mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([]))
# Create executor
executor = StrandsA2AExecutor(mock_strands_agent)
@@ -155,6 +530,17 @@ async def mock_stream(user_input):
# Mock no existing task
mock_request_context.current_task = None
+ # Mock message with parts
+ from a2a.types import TextPart
+
+ mock_message = MagicMock()
+ text_part = MagicMock(spec=TextPart)
+ text_part.text = "Test input"
+ part = MagicMock()
+ part.root = text_part
+ mock_message.parts = [part]
+ mock_request_context.message = mock_message
+
with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task:
mock_new_task.return_value = MagicMock(id="new-task-id", context_id="new-context-id")
@@ -183,11 +569,22 @@ async def test_execute_streaming_mode_handles_agent_exception(
mock_task.context_id = "test-context-id"
mock_request_context.current_task = mock_task
+ # Mock message with parts
+ from a2a.types import TextPart
+
+ mock_message = MagicMock()
+ text_part = MagicMock(spec=TextPart)
+ text_part.text = "Test input"
+ part = MagicMock()
+ part.root = text_part
+ mock_message.parts = [part]
+ mock_request_context.message = mock_message
+
with pytest.raises(ServerError):
await executor.execute(mock_request_context, mock_event_queue)
# Verify agent was called
- mock_strands_agent.stream_async.assert_called_once_with("Test input")
+ mock_strands_agent.stream_async.assert_called_once()
@pytest.mark.asyncio
@@ -252,3 +649,353 @@ async def test_handle_agent_result_with_result_but_no_message(
# Verify completion was called
mock_updater.complete.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_handle_agent_result_with_content(mock_strands_agent):
+ """Test that _handle_agent_result handles result with content correctly."""
+ executor = StrandsA2AExecutor(mock_strands_agent)
+
+ # Mock TaskUpdater
+ mock_updater = MagicMock()
+ mock_updater.complete = AsyncMock()
+ mock_updater.add_artifact = AsyncMock()
+
+ # Create result with content
+ mock_result = MagicMock(spec=SAAgentResult)
+ mock_result.__str__ = MagicMock(return_value="Test response content")
+
+ # Call _handle_agent_result
+ await executor._handle_agent_result(mock_result, mock_updater)
+
+ # Verify artifact was added and task completed
+ mock_updater.add_artifact.assert_called_once()
+ mock_updater.complete.assert_called_once()
+
+ # Check that the artifact contains the expected content
+ call_args = mock_updater.add_artifact.call_args[0][0]
+ assert len(call_args) == 1
+ assert call_args[0].root.text == "Test response content"
+
+
+def test_handle_conversion_error():
+ """Test that conversion handles errors gracefully."""
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Mock Part that will raise an exception during processing
+ problematic_part = MagicMock()
+ problematic_part.root = None # This should cause an AttributeError
+
+ # Should not raise an exception, but return empty list or handle gracefully
+ result = executor._convert_a2a_parts_to_content_blocks([problematic_part])
+
+ # The method should handle the error and continue
+ assert isinstance(result, list)
+
+
+def test_convert_a2a_parts_to_content_blocks_empty_list():
+ """Test conversion with empty parts list."""
+ executor = StrandsA2AExecutor(MagicMock())
+
+ result = executor._convert_a2a_parts_to_content_blocks([])
+
+ assert result == []
+
+
+def test_convert_a2a_parts_to_content_blocks_file_part_no_name():
+ """Test conversion of FilePart with no file name."""
+ from a2a.types import FilePart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Mock file object without name
+ file_obj = MagicMock()
+ delattr(file_obj, "name") # Remove name attribute
+ file_obj.mime_type = "text/plain"
+ file_obj.bytes = b"test content"
+ file_obj.uri = None
+
+ # Mock FilePart with proper spec
+ file_part = MagicMock(spec=FilePart)
+ file_part.file = file_obj
+
+ # Mock Part with FilePart root
+ part = MagicMock()
+ part.root = file_part
+
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ assert len(result) == 1
+ content_block = result[0]
+ assert "document" in content_block
+ assert content_block["document"]["name"] == "FileNameNotProvided" # Should use default
+
+
+def test_convert_a2a_parts_to_content_blocks_file_part_no_mime_type():
+ """Test conversion of FilePart with no MIME type."""
+ from a2a.types import FilePart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Mock file object without MIME type
+ file_obj = MagicMock()
+ file_obj.name = "test_file"
+ delattr(file_obj, "mime_type")
+ file_obj.bytes = b"test content"
+ file_obj.uri = None
+
+ # Mock FilePart with proper spec
+ file_part = MagicMock(spec=FilePart)
+ file_part.file = file_obj
+
+ # Mock Part with FilePart root
+ part = MagicMock()
+ part.root = file_part
+
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ assert len(result) == 1
+ content_block = result[0]
+ assert "document" in content_block # Should default to document with unknown type
+ assert content_block["document"]["format"] == "txt" # Should use default format for unknown file type
+
+
+def test_convert_a2a_parts_to_content_blocks_file_part_no_bytes_no_uri():
+ """Test conversion of FilePart with neither bytes nor URI."""
+ from a2a.types import FilePart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Mock file object without bytes or URI
+ file_obj = MagicMock()
+ file_obj.name = "test_file.txt"
+ file_obj.mime_type = "text/plain"
+ file_obj.bytes = None
+ file_obj.uri = None
+
+ # Mock FilePart with proper spec
+ file_part = MagicMock(spec=FilePart)
+ file_part.file = file_obj
+
+ # Mock Part with FilePart root
+ part = MagicMock()
+ part.root = file_part
+
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ # Should return empty list since no fallback case exists
+ assert len(result) == 0
+
+
+def test_convert_a2a_parts_to_content_blocks_data_part_serialization_error():
+ """Test conversion of DataPart with non-serializable data."""
+ from a2a.types import DataPart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Create non-serializable data (e.g., a function)
+ def non_serializable():
+ pass
+
+ # Mock DataPart with proper spec
+ data_part = MagicMock(spec=DataPart)
+ data_part.data = {"function": non_serializable} # This will cause JSON serialization to fail
+
+ # Mock Part with DataPart root
+ part = MagicMock()
+ part.root = data_part
+
+ # Should not raise an exception, should handle gracefully
+ result = executor._convert_a2a_parts_to_content_blocks([part])
+
+ # The error handling should result in an empty list or the part being skipped
+ assert isinstance(result, list)
+
+
+@pytest.mark.asyncio
+async def test_execute_streaming_mode_raises_error_for_empty_content_blocks(
+ mock_strands_agent, mock_event_queue, mock_request_context
+):
+ """Test that execute raises ServerError when content blocks are empty after conversion."""
+ executor = StrandsA2AExecutor(mock_strands_agent)
+
+ # Create a mock message with parts that will result in empty content blocks
+ # This could happen if all parts fail to convert or are invalid
+ mock_message = MagicMock()
+ mock_message.parts = [MagicMock()] # Has parts but they won't convert to valid content blocks
+ mock_request_context.message = mock_message
+
+ # Mock the conversion to return empty list
+ with patch.object(executor, "_convert_a2a_parts_to_content_blocks", return_value=[]):
+ with pytest.raises(ServerError) as excinfo:
+ await executor.execute(mock_request_context, mock_event_queue)
+
+ # Verify the error is a ServerError containing an InternalError
+ assert isinstance(excinfo.value.error, InternalError)
+
+
+@pytest.mark.asyncio
+async def test_execute_with_mixed_part_types(mock_strands_agent, mock_request_context, mock_event_queue):
+ """Test execute with a message containing mixed A2A part types."""
+ from a2a.types import DataPart, FilePart, TextPart
+
+ async def mock_stream(content_blocks):
+ """Mock streaming function."""
+ yield {"data": "Processing mixed content"}
+ yield {"result": MagicMock(spec=SAAgentResult)}
+
+ # Setup mock agent streaming
+ mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([]))
+
+ # Create executor
+ executor = StrandsA2AExecutor(mock_strands_agent)
+
+ # Mock the task creation
+ mock_task = MagicMock()
+ mock_task.id = "test-task-id"
+ mock_task.context_id = "test-context-id"
+ mock_request_context.current_task = mock_task
+
+ # Create mixed parts
+ text_part = MagicMock(spec=TextPart)
+ text_part.text = "Hello"
+ text_part_mock = MagicMock()
+ text_part_mock.root = text_part
+
+ # File part with bytes
+ file_obj = MagicMock()
+ file_obj.name = "image.png"
+ file_obj.mime_type = "image/png"
+ file_obj.bytes = b"fake_image"
+ file_obj.uri = None
+ file_part = MagicMock(spec=FilePart)
+ file_part.file = file_obj
+ file_part_mock = MagicMock()
+ file_part_mock.root = file_part
+
+ # Data part
+ data_part = MagicMock(spec=DataPart)
+ data_part.data = {"key": "value"}
+ data_part_mock = MagicMock()
+ data_part_mock.root = data_part
+
+ # Mock message with mixed parts
+ mock_message = MagicMock()
+ mock_message.parts = [text_part_mock, file_part_mock, data_part_mock]
+ mock_request_context.message = mock_message
+
+ await executor.execute(mock_request_context, mock_event_queue)
+
+ # Verify agent was called with ContentBlock list containing all types
+ mock_strands_agent.stream_async.assert_called_once()
+ call_args = mock_strands_agent.stream_async.call_args[0][0]
+ assert isinstance(call_args, list)
+ assert len(call_args) == 3 # Should have converted all 3 parts
+
+ # Check that we have text, image, and structured data
+ has_text = any("text" in block for block in call_args)
+ has_image = any("image" in block for block in call_args)
+ has_structured_data = any("text" in block and "[Structured Data]" in block.get("text", "") for block in call_args)
+
+ assert has_text
+ assert has_image
+ assert has_structured_data
+
+
+def test_integration_example():
+ """Integration test example showing how A2A Parts are converted to ContentBlocks.
+
+ This test serves as documentation for the conversion functionality.
+ """
+ from a2a.types import DataPart, FilePart, TextPart
+
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Example 1: Text content
+ text_part = MagicMock(spec=TextPart)
+ text_part.text = "Hello, this is a text message"
+ text_part_mock = MagicMock()
+ text_part_mock.root = text_part
+
+ # Example 2: Image file
+ image_bytes = b"fake_image_content"
+ image_file = MagicMock()
+ image_file.name = "photo.jpg"
+ image_file.mime_type = "image/jpeg"
+ image_file.bytes = image_bytes
+ image_file.uri = None
+
+ image_part = MagicMock(spec=FilePart)
+ image_part.file = image_file
+ image_part_mock = MagicMock()
+ image_part_mock.root = image_part
+
+ # Example 3: Document file
+ doc_bytes = b"PDF document content"
+ doc_file = MagicMock()
+ doc_file.name = "report.pdf"
+ doc_file.mime_type = "application/pdf"
+ doc_file.bytes = doc_bytes
+ doc_file.uri = None
+
+ doc_part = MagicMock(spec=FilePart)
+ doc_part.file = doc_file
+ doc_part_mock = MagicMock()
+ doc_part_mock.root = doc_part
+
+ # Example 4: Structured data
+ data_part = MagicMock(spec=DataPart)
+ data_part.data = {"user": "john_doe", "action": "upload_file", "timestamp": "2023-12-01T10:00:00Z"}
+ data_part_mock = MagicMock()
+ data_part_mock.root = data_part
+
+ # Convert all parts to ContentBlocks
+ parts = [text_part_mock, image_part_mock, doc_part_mock, data_part_mock]
+ content_blocks = executor._convert_a2a_parts_to_content_blocks(parts)
+
+ # Verify conversion results
+ assert len(content_blocks) == 4
+
+ # Text part becomes text ContentBlock
+ assert content_blocks[0]["text"] == "Hello, this is a text message"
+
+ # Image part becomes image ContentBlock with proper format and bytes
+ assert "image" in content_blocks[1]
+ assert content_blocks[1]["image"]["format"] == "jpeg"
+ assert content_blocks[1]["image"]["source"]["bytes"] == image_bytes
+
+ # Document part becomes document ContentBlock
+ assert "document" in content_blocks[2]
+ assert content_blocks[2]["document"]["format"] == "pdf"
+ assert content_blocks[2]["document"]["name"] == "report" # Extension stripped
+ assert content_blocks[2]["document"]["source"]["bytes"] == doc_bytes
+
+ # Data part becomes text ContentBlock with JSON representation
+ assert "text" in content_blocks[3]
+ assert "[Structured Data]" in content_blocks[3]["text"]
+ assert "john_doe" in content_blocks[3]["text"]
+ assert "upload_file" in content_blocks[3]["text"]
+
+
+def test_default_formats_modularization():
+ """Test that DEFAULT_FORMATS mapping works correctly for modular format defaults."""
+ executor = StrandsA2AExecutor(MagicMock())
+
+ # Test that DEFAULT_FORMATS contains expected mappings
+ assert hasattr(executor, "DEFAULT_FORMATS")
+ assert executor.DEFAULT_FORMATS["document"] == "txt"
+ assert executor.DEFAULT_FORMATS["image"] == "png"
+ assert executor.DEFAULT_FORMATS["video"] == "mp4"
+ assert executor.DEFAULT_FORMATS["unknown"] == "txt"
+
+ # Test format selection with None mime_type
+ assert executor._get_file_format_from_mime_type(None, "document") == "txt"
+ assert executor._get_file_format_from_mime_type(None, "image") == "png"
+ assert executor._get_file_format_from_mime_type(None, "video") == "mp4"
+ assert executor._get_file_format_from_mime_type(None, "unknown") == "txt"
+ assert executor._get_file_format_from_mime_type(None, "nonexistent") == "txt" # fallback
+
+ # Test format selection with empty mime_type
+ assert executor._get_file_format_from_mime_type("", "document") == "txt"
+ assert executor._get_file_format_from_mime_type("", "image") == "png"
+ assert executor._get_file_format_from_mime_type("", "video") == "mp4"
From 60dcb454c550002379444c698867b3f5e49fd490 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Wed, 20 Aug 2025 15:17:59 -0400
Subject: [PATCH 043/104] ci: update pre-commit requirement from <4.2.0,>=3.2.0
to >=3.2.0,<4.4.0 (#706)
Updates the requirements on [pre-commit](https://github.com/pre-commit/pre-commit) to permit the latest version.
- [Release notes](https://github.com/pre-commit/pre-commit/releases)
- [Changelog](https://github.com/pre-commit/pre-commit/blob/main/CHANGELOG.md)
- [Commits](https://github.com/pre-commit/pre-commit/compare/v3.2.0...v4.3.0)
---
updated-dependencies:
- dependency-name: pre-commit
dependency-version: 4.3.0
dependency-type: direct:production
...
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 847db8d2b..de28c311c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,7 +55,7 @@ dev = [
"hatch>=1.0.0,<2.0.0",
"moto>=5.1.0,<6.0.0",
"mypy>=1.15.0,<2.0.0",
- "pre-commit>=3.2.0,<4.2.0",
+ "pre-commit>=3.2.0,<4.4.0",
"pytest>=8.0.0,<9.0.0",
"pytest-asyncio>=0.26.0,<0.27.0",
"pytest-cov>=4.1.0,<5.0.0",
From b61a06416b250693f162cd490b941643cdbefbc5 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 21 Aug 2025 09:39:50 -0400
Subject: [PATCH 044/104] ci: update ruff requirement from <0.5.0,>=0.4.4 to
>=0.4.4,<0.13.0 (#704)
* ci: update ruff requirement from <0.5.0,>=0.4.4 to >=0.4.4,<0.13.0
Updates the requirements on [ruff](https://github.com/astral-sh/ruff) to permit the latest version.
- [Release notes](https://github.com/astral-sh/ruff/releases)
- [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md)
- [Commits](https://github.com/astral-sh/ruff/compare/v0.4.4...0.12.9)
---
updated-dependencies:
- dependency-name: ruff
dependency-version: 0.12.9
dependency-type: direct:production
...
Signed-off-by: dependabot[bot]
* Apply suggestions from code review
Co-authored-by: Patrick Gray
---------
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jonathan Segev
Co-authored-by: Patrick Gray
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index de28c311c..124ba5653 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,7 +60,7 @@ dev = [
"pytest-asyncio>=0.26.0,<0.27.0",
"pytest-cov>=4.1.0,<5.0.0",
"pytest-xdist>=3.0.0,<4.0.0",
- "ruff>=0.4.4,<0.5.0",
+ "ruff>=0.12.0,<0.13.0",
]
docs = [
"sphinx>=5.0.0,<6.0.0",
From 93d3ac83573d6085e02b165541b55c8da3d10bce Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 21 Aug 2025 09:40:01 -0400
Subject: [PATCH 045/104] ci: update pytest-asyncio requirement from
<0.27.0,>=0.26.0 to >=0.26.0,<1.2.0 (#708)
* ci: update pytest-asyncio requirement
Updates the requirements on [pytest-asyncio](https://github.com/pytest-dev/pytest-asyncio) to permit the latest version.
- [Release notes](https://github.com/pytest-dev/pytest-asyncio/releases)
- [Commits](https://github.com/pytest-dev/pytest-asyncio/compare/v0.26.0...v1.1.0)
---
updated-dependencies:
- dependency-name: pytest-asyncio
dependency-version: 1.1.0
dependency-type: direct:production
...
Signed-off-by: dependabot[bot]
* Apply suggestions from code review
Co-authored-by: Patrick Gray
---------
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jonathan Segev
Co-authored-by: Patrick Gray
---
pyproject.toml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 124ba5653..f91454414 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,7 +57,7 @@ dev = [
"mypy>=1.15.0,<2.0.0",
"pre-commit>=3.2.0,<4.4.0",
"pytest>=8.0.0,<9.0.0",
- "pytest-asyncio>=0.26.0,<0.27.0",
+ "pytest-asyncio>=1.0.0,<1.2.0",
"pytest-cov>=4.1.0,<5.0.0",
"pytest-xdist>=3.0.0,<4.0.0",
"ruff>=0.12.0,<0.13.0",
@@ -143,7 +143,7 @@ features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mis
extra-dependencies = [
"moto>=5.1.0,<6.0.0",
"pytest>=8.0.0,<9.0.0",
- "pytest-asyncio>=0.26.0,<0.27.0",
+ "pytest-asyncio>=1.0.0,<1.2.0",
"pytest-cov>=4.1.0,<5.0.0",
"pytest-xdist>=3.0.0,<4.0.0",
]
From 9397f58a953b83a7190e686ac6e29fa6d4e8ac86 Mon Sep 17 00:00:00 2001
From: Xwei
Date: Thu, 21 Aug 2025 22:19:13 +0800
Subject: [PATCH 046/104] fix: add system_prompt to structured_output_span
before adding input_messages (#709)
* fix: add system_prompt to structured_output_span before adding input_messages
* test: Add system message ordering validation to agent structured output test
* Switch to ensuring exact ordering of messages
---------
Co-authored-by: Dennis Tsai (RD-AS)
Co-authored-by: Mackenzie Zastrow
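The fix is purely an ordering change: the system message event must be added to the span before the per-message events so consumers see the conversation in prompt order. The expected event sequence, as the updated test below asserts:

    gen_ai.system.message  (role=system, content=[{"text": system_prompt}])
    gen_ai.user.message    (role=user, content=...)
    gen_ai.choice          (message=...)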
---
src/strands/agent/agent.py | 10 +++++-----
tests/strands/agent/test_agent.py | 25 +++++++++++++++++--------
2 files changed, 22 insertions(+), 13 deletions(-)
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index acc6a7650..5150060c6 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -470,16 +470,16 @@ async def structured_output_async(
"gen_ai.operation.name": "execute_structured_output",
}
)
- for message in temp_messages:
- structured_output_span.add_event(
- f"gen_ai.{message['role']}.message",
- attributes={"role": message["role"], "content": serialize(message["content"])},
- )
if self.system_prompt:
structured_output_span.add_event(
"gen_ai.system.message",
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
)
+ for message in temp_messages:
+ structured_output_span.add_event(
+ f"gen_ai.{message['role']}.message",
+ attributes={"role": message["role"], "content": serialize(message["content"])},
+ )
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
async for event in events:
if "callback" in event:
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 444232455..7e769c6d7 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -18,6 +18,7 @@
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
from strands.session.repository_session_manager import RepositorySessionManager
+from strands.telemetry.tracer import serialize
from strands.types.content import Messages
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException
from strands.types.session import Session, SessionAgent, SessionMessage, SessionType
@@ -1028,15 +1029,23 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
}
)
- mock_span.add_event.assert_any_call(
- "gen_ai.user.message",
- attributes={"role": "user", "content": '[{"text": "Jane Doe is 30 years old and her email is jane@doe.com"}]'},
- )
+ # ensure correct otel event messages are emitted
+ act_event_names = mock_span.add_event.call_args_list
+ exp_event_names = [
+ unittest.mock.call(
+ "gen_ai.system.message", attributes={"role": "system", "content": serialize([{"text": system_prompt}])}
+ ),
+ unittest.mock.call(
+ "gen_ai.user.message",
+ attributes={
+ "role": "user",
+ "content": '[{"text": "Jane Doe is 30 years old and her email is jane@doe.com"}]',
+ },
+ ),
+ unittest.mock.call("gen_ai.choice", attributes={"message": json.dumps(user.model_dump())}),
+ ]
- mock_span.add_event.assert_called_with(
- "gen_ai.choice",
- attributes={"message": json.dumps(user.model_dump())},
- )
+ assert act_event_names == exp_event_names
def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, agenerator):
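The reordering above matters because OpenTelemetry records span events in call order: whichever `add_event` runs first appears first on the span, so the system prompt must be emitted before the conversation messages for the span to read as a well-ordered transcript. A self-contained illustration of that property (tracer name and attributes are placeholders, not part of the patch):

    from opentelemetry import trace

    tracer = trace.get_tracer("ordering-demo")

    with tracer.start_as_current_span("execute_structured_output") as span:
        # Span events keep insertion order, so the system message goes first.
        span.add_event("gen_ai.system.message", attributes={"role": "system"})
        span.add_event("gen_ai.user.message", attributes={"role": "user"})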
From 6ef64478d7fde3c677ea13cadf068422a3d01377 Mon Sep 17 00:00:00 2001
From: Murat Kaan Meral
Date: Mon, 25 Aug 2025 16:16:50 +0300
Subject: [PATCH 047/104] feat(multiagent): Add __call__ implementation to
MultiAgentBase (#645)
---
src/strands/multiagent/base.py | 11 +++++++--
tests/strands/multiagent/test_base.py | 34 +++++++++++++++++++++++----
2 files changed, 39 insertions(+), 6 deletions(-)
diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py
index c6b1af702..69578cb5d 100644
--- a/src/strands/multiagent/base.py
+++ b/src/strands/multiagent/base.py
@@ -3,7 +3,9 @@
Provides minimal foundation for multi-agent patterns (Swarm, Graph).
"""
+import asyncio
from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Union
@@ -86,7 +88,12 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> M
"""Invoke asynchronously."""
raise NotImplementedError("invoke_async not implemented")
- @abstractmethod
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
"""Invoke synchronously."""
- raise NotImplementedError("__call__ not implemented")
+
+ def execute() -> MultiAgentResult:
+ return asyncio.run(self.invoke_async(task, **kwargs))
+
+ with ThreadPoolExecutor() as executor:
+ future = executor.submit(execute)
+ return future.result()
diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py
index 7aa76bb90..395d9275c 100644
--- a/tests/strands/multiagent/test_base.py
+++ b/tests/strands/multiagent/test_base.py
@@ -141,9 +141,35 @@ class CompleteMultiAgent(MultiAgentBase):
async def invoke_async(self, task: str) -> MultiAgentResult:
return MultiAgentResult(results={})
- def __call__(self, task: str) -> MultiAgentResult:
- return MultiAgentResult(results={})
-
- # Should not raise an exception
+ # Should not raise an exception - __call__ is provided by base class
agent = CompleteMultiAgent()
assert isinstance(agent, MultiAgentBase)
+
+
+def test_multi_agent_base_call_method():
+ """Test that __call__ method properly delegates to invoke_async."""
+
+ class TestMultiAgent(MultiAgentBase):
+ def __init__(self):
+ self.invoke_async_called = False
+ self.received_task = None
+ self.received_kwargs = None
+
+ async def invoke_async(self, task, **kwargs):
+ self.invoke_async_called = True
+ self.received_task = task
+ self.received_kwargs = kwargs
+ return MultiAgentResult(
+ status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)}
+ )
+
+ agent = TestMultiAgent()
+
+ # Test with string task
+ result = agent("test task", param1="value1", param2="value2")
+
+ assert agent.invoke_async_called
+ assert agent.received_task == "test task"
+ assert agent.received_kwargs == {"param1": "value1", "param2": "value2"}
+ assert isinstance(result, MultiAgentResult)
+ assert result.status == Status.COMPLETED
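The `ThreadPoolExecutor` indirection in the new `__call__` is deliberate: `asyncio.run` raises `RuntimeError` when invoked from a thread that already has a running event loop (for example, inside a Jupyter cell or another agent's loop), so the coroutine is driven on a fresh worker thread instead. The same pattern in isolation (names are illustrative):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    async def invoke_async(task: str) -> str:
        return f"done: {task}"

    def invoke_sync(task: str) -> str:
        # Drive the coroutine on a worker thread so asyncio.run never clashes
        # with an event loop that may already be running in the calling thread.
        with ThreadPoolExecutor() as executor:
            return executor.submit(lambda: asyncio.run(invoke_async(task))).result()

    print(invoke_sync("demo"))  # -> done: demo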
From e4879e18121985b860d0f9e3556c0bf7e512a4a7 Mon Sep 17 00:00:00 2001
From: mehtarac
Date: Mon, 25 Aug 2025 06:40:26 -0700
Subject: [PATCH 048/104] chore: Update pydantic minimum version (#723)
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index f91454414..32de94aa6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,7 @@ dependencies = [
"botocore>=1.29.0,<2.0.0",
"docstring_parser>=0.15,<1.0",
"mcp>=1.11.0,<2.0.0",
- "pydantic>=2.0.0,<3.0.0",
+ "pydantic>=2.4.0,<3.0.0",
"typing-extensions>=4.13.2,<5.0.0",
"watchdog>=6.0.0,<7.0.0",
"opentelemetry-api>=1.30.0,<2.0.0",
From c18ef930ee7c436f7af58845001a2e02014b52da Mon Sep 17 00:00:00 2001
From: Patrick Gray
Date: Mon, 25 Aug 2025 10:16:04 -0400
Subject: [PATCH 049/104] tool executors (#658)
---
src/strands/agent/agent.py | 15 +-
src/strands/event_loop/event_loop.py | 155 +-----
src/strands/tools/_validator.py | 45 ++
src/strands/tools/executor.py | 137 ------
src/strands/tools/executors/__init__.py | 16 +
src/strands/tools/executors/_executor.py | 227 +++++++++
src/strands/tools/executors/concurrent.py | 113 +++++
src/strands/tools/executors/sequential.py | 46 ++
tests/strands/agent/test_agent.py | 44 +-
tests/strands/event_loop/test_event_loop.py | 271 +----------
tests/strands/tools/executors/conftest.py | 116 +++++
.../tools/executors/test_concurrent.py | 32 ++
.../strands/tools/executors/test_executor.py | 144 ++++++
.../tools/executors/test_sequential.py | 32 ++
tests/strands/tools/test_executor.py | 440 ------------------
tests/strands/tools/test_validator.py | 50 ++
.../tools/executors/test_concurrent.py | 61 +++
.../tools/executors/test_sequential.py | 61 +++
18 files changed, 985 insertions(+), 1020 deletions(-)
create mode 100644 src/strands/tools/_validator.py
delete mode 100644 src/strands/tools/executor.py
create mode 100644 src/strands/tools/executors/__init__.py
create mode 100644 src/strands/tools/executors/_executor.py
create mode 100644 src/strands/tools/executors/concurrent.py
create mode 100644 src/strands/tools/executors/sequential.py
create mode 100644 tests/strands/tools/executors/conftest.py
create mode 100644 tests/strands/tools/executors/test_concurrent.py
create mode 100644 tests/strands/tools/executors/test_executor.py
create mode 100644 tests/strands/tools/executors/test_sequential.py
delete mode 100644 tests/strands/tools/test_executor.py
create mode 100644 tests/strands/tools/test_validator.py
create mode 100644 tests_integ/tools/executors/test_concurrent.py
create mode 100644 tests_integ/tools/executors/test_sequential.py
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 5150060c6..adc554bf4 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -20,7 +20,7 @@
from pydantic import BaseModel
from .. import _identifier
-from ..event_loop.event_loop import event_loop_cycle, run_tool
+from ..event_loop.event_loop import event_loop_cycle
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from ..hooks import (
AfterInvocationEvent,
@@ -35,6 +35,8 @@
from ..session.session_manager import SessionManager
from ..telemetry.metrics import EventLoopMetrics
from ..telemetry.tracer import get_tracer, serialize
+from ..tools.executors import ConcurrentToolExecutor
+from ..tools.executors._executor import ToolExecutor
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
from ..types.content import ContentBlock, Message, Messages
@@ -136,13 +138,14 @@ def caller(
"name": normalized_name,
"input": kwargs.copy(),
}
+ tool_results: list[ToolResult] = []
+ invocation_state = kwargs
async def acall() -> ToolResult:
- # Pass kwargs as invocation_state
- async for event in run_tool(self._agent, tool_use, kwargs):
+ async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
_ = event
- return cast(ToolResult, event)
+ return tool_results[0]
def tcall() -> ToolResult:
return asyncio.run(acall())
@@ -208,6 +211,7 @@ def __init__(
state: Optional[Union[AgentState, dict]] = None,
hooks: Optional[list[HookProvider]] = None,
session_manager: Optional[SessionManager] = None,
+ tool_executor: Optional[ToolExecutor] = None,
):
"""Initialize the Agent with the specified configuration.
@@ -250,6 +254,7 @@ def __init__(
Defaults to None.
session_manager: Manager for handling agent sessions including conversation history and state.
If provided, enables session-based persistence and state management.
+            tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.).
Raises:
ValueError: If agent id contains path separators.
@@ -324,6 +329,8 @@ def __init__(
if self._session_manager:
self.hooks.add_hook(self._session_manager)
+ self.tool_executor = tool_executor or ConcurrentToolExecutor()
+
if hooks:
for hook in hooks:
self.hooks.add_hook(hook)
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index b36f73155..524ecc3e8 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -11,22 +11,20 @@
import logging
import time
import uuid
-from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
+from typing import TYPE_CHECKING, Any, AsyncGenerator
from opentelemetry import trace as trace_api
from ..experimental.hooks import (
AfterModelInvocationEvent,
- AfterToolInvocationEvent,
BeforeModelInvocationEvent,
- BeforeToolInvocationEvent,
)
from ..hooks import (
MessageAddedEvent,
)
from ..telemetry.metrics import Trace
from ..telemetry.tracer import get_tracer
-from ..tools.executor import run_tools, validate_and_prepare_tools
+from ..tools._validator import validate_and_prepare_tools
from ..types.content import Message
from ..types.exceptions import (
ContextWindowOverflowException,
@@ -35,7 +33,7 @@
ModelThrottledException,
)
from ..types.streaming import Metrics, StopReason
-from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
+from ..types.tools import ToolResult, ToolUse
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
from .streaming import stream_messages
@@ -212,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
if stop_reason == "max_tokens":
"""
Handle max_tokens limit reached by the model.
-
+
When the model reaches its maximum token limit, this represents a potentially unrecoverable
state where the model's response was truncated. By default, Strands fails hard with an
MaxTokensReachedException to maintain consistency with other failure types.
@@ -306,122 +304,6 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
recursive_trace.end()
-async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator:
- """Process a tool invocation.
-
- Looks up the tool in the registry and streams it with the provided parameters.
-
- Args:
- agent: The agent for which the tool is being executed.
- tool_use: The tool object to process, containing name and parameters.
- invocation_state: Context for the tool invocation, including agent state.
-
- Yields:
- Tool events with the last being the tool result.
- """
- logger.debug("tool_use=<%s> | streaming", tool_use)
- tool_name = tool_use["name"]
-
- # Get the tool info
- tool_info = agent.tool_registry.dynamic_tools.get(tool_name)
- tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name)
-
- # Add standard arguments to invocation_state for Python tools
- invocation_state.update(
- {
- "model": agent.model,
- "system_prompt": agent.system_prompt,
- "messages": agent.messages,
- "tool_config": ToolConfig( # for backwards compatability
- tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()],
- toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}),
- ),
- }
- )
-
- before_event = agent.hooks.invoke_callbacks(
- BeforeToolInvocationEvent(
- agent=agent,
- selected_tool=tool_func,
- tool_use=tool_use,
- invocation_state=invocation_state,
- )
- )
-
- try:
- selected_tool = before_event.selected_tool
- tool_use = before_event.tool_use
- invocation_state = before_event.invocation_state # Get potentially modified invocation_state from hook
-
- # Check if tool exists
- if not selected_tool:
- if tool_func == selected_tool:
- logger.error(
- "tool_name=<%s>, available_tools=<%s> | tool not found in registry",
- tool_name,
- list(agent.tool_registry.registry.keys()),
- )
- else:
- logger.debug(
- "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
- tool_name,
- str(tool_use.get("toolUseId")),
- )
-
- result: ToolResult = {
- "toolUseId": str(tool_use.get("toolUseId")),
- "status": "error",
- "content": [{"text": f"Unknown tool: {tool_name}"}],
- }
- # for every Before event call, we need to have an AfterEvent call
- after_event = agent.hooks.invoke_callbacks(
- AfterToolInvocationEvent(
- agent=agent,
- selected_tool=selected_tool,
- tool_use=tool_use,
- invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks
- result=result,
- )
- )
- yield after_event.result
- return
-
- async for event in selected_tool.stream(tool_use, invocation_state):
- yield event
-
- result = event
-
- after_event = agent.hooks.invoke_callbacks(
- AfterToolInvocationEvent(
- agent=agent,
- selected_tool=selected_tool,
- tool_use=tool_use,
- invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks
- result=result,
- )
- )
- yield after_event.result
-
- except Exception as e:
- logger.exception("tool_name=<%s> | failed to process tool", tool_name)
- error_result: ToolResult = {
- "toolUseId": str(tool_use.get("toolUseId")),
- "status": "error",
- "content": [{"text": f"Error: {str(e)}"}],
- }
- after_event = agent.hooks.invoke_callbacks(
- AfterToolInvocationEvent(
- agent=agent,
- selected_tool=selected_tool,
- tool_use=tool_use,
- invocation_state=invocation_state, # Keep as invocation_state for backward compatibility with hooks
- result=error_result,
- exception=e,
- )
- )
- yield after_event.result
-
-
async def _handle_tool_execution(
stop_reason: StopReason,
message: Message,
@@ -431,18 +313,12 @@ async def _handle_tool_execution(
cycle_start_time: float,
invocation_state: dict[str, Any],
) -> AsyncGenerator[dict[str, Any], None]:
- tool_uses: list[ToolUse] = []
- tool_results: list[ToolResult] = []
- invalid_tool_use_ids: list[str] = []
-
- """
- Handles the execution of tools requested by the model during an event loop cycle.
+ """Handles the execution of tools requested by the model during an event loop cycle.
Args:
stop_reason: The reason the model stopped generating.
message: The message from the model that may contain tool use requests.
- event_loop_metrics: Metrics tracking object for the event loop.
- event_loop_parent_span: Span for the parent of this event loop.
+ agent: Agent for which tools are being executed.
cycle_trace: Trace object for the current event loop cycle.
cycle_span: Span object for tracing the cycle (type may vary).
cycle_start_time: Start time of the current cycle.
@@ -456,23 +332,18 @@ async def _handle_tool_execution(
- The updated event loop metrics,
- The updated request state.
"""
- validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
+ tool_uses: list[ToolUse] = []
+ tool_results: list[ToolResult] = []
+ invalid_tool_use_ids: list[str] = []
+ validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
+ tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
if not tool_uses:
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
return
- def tool_handler(tool_use: ToolUse) -> ToolGenerator:
- return run_tool(agent, tool_use, invocation_state)
-
- tool_events = run_tools(
- handler=tool_handler,
- tool_uses=tool_uses,
- event_loop_metrics=agent.event_loop_metrics,
- invalid_tool_use_ids=invalid_tool_use_ids,
- tool_results=tool_results,
- cycle_trace=cycle_trace,
- parent_span=cycle_span,
+ tool_events = agent.tool_executor._execute(
+ agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state
)
async for tool_event in tool_events:
yield tool_event
diff --git a/src/strands/tools/_validator.py b/src/strands/tools/_validator.py
new file mode 100644
index 000000000..77aa57e87
--- /dev/null
+++ b/src/strands/tools/_validator.py
@@ -0,0 +1,45 @@
+"""Tool validation utilities."""
+
+from ..tools.tools import InvalidToolUseNameException, validate_tool_use
+from ..types.content import Message
+from ..types.tools import ToolResult, ToolUse
+
+
+def validate_and_prepare_tools(
+ message: Message,
+ tool_uses: list[ToolUse],
+ tool_results: list[ToolResult],
+ invalid_tool_use_ids: list[str],
+) -> None:
+ """Validate tool uses and prepare them for execution.
+
+ Args:
+ message: Current message.
+ tool_uses: List to populate with tool uses.
+ tool_results: List to populate with tool results for invalid tools.
+ invalid_tool_use_ids: List to populate with invalid tool use IDs.
+ """
+ # Extract tool uses from message
+ for content in message["content"]:
+ if isinstance(content, dict) and "toolUse" in content:
+ tool_uses.append(content["toolUse"])
+
+ # Validate tool uses
+ # Avoid modifying original `tool_uses` variable during iteration
+ tool_uses_copy = tool_uses.copy()
+ for tool in tool_uses_copy:
+ try:
+ validate_tool_use(tool)
+ except InvalidToolUseNameException as e:
+ # Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context
+ tool_uses.remove(tool)
+ tool["name"] = "INVALID_TOOL_NAME"
+ invalid_tool_use_ids.append(tool["toolUseId"])
+ tool_uses.append(tool)
+ tool_results.append(
+ {
+ "toolUseId": tool["toolUseId"],
+ "status": "error",
+ "content": [{"text": f"Error: {str(e)}"}],
+ }
+ )
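The extracted `validate_and_prepare_tools` keeps its populate-by-reference contract: callers pass in empty lists and read the results back out afterwards. A hypothetical call, assuming `validate_tool_use` rejects the space in the tool name below:

    from strands.tools._validator import validate_and_prepare_tools

    message = {
        "role": "assistant",
        "content": [{"toolUse": {"toolUseId": "t1", "name": "bad name", "input": {}}}],
    }
    tool_uses, tool_results, invalid_ids = [], [], []
    validate_and_prepare_tools(message, tool_uses, tool_results, invalid_ids)
    # invalid_ids == ["t1"]; tool_results carries the error result back to the
    # model, and the offending tool_uses entry is renamed "INVALID_TOOL_NAME".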
diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py
deleted file mode 100644
index d90f9a5aa..000000000
--- a/src/strands/tools/executor.py
+++ /dev/null
@@ -1,137 +0,0 @@
-"""Tool execution functionality for the event loop."""
-
-import asyncio
-import logging
-import time
-from typing import Any, Optional, cast
-
-from opentelemetry import trace as trace_api
-
-from ..telemetry.metrics import EventLoopMetrics, Trace
-from ..telemetry.tracer import get_tracer
-from ..tools.tools import InvalidToolUseNameException, validate_tool_use
-from ..types.content import Message
-from ..types.tools import RunToolHandler, ToolGenerator, ToolResult, ToolUse
-
-logger = logging.getLogger(__name__)
-
-
-async def run_tools(
- handler: RunToolHandler,
- tool_uses: list[ToolUse],
- event_loop_metrics: EventLoopMetrics,
- invalid_tool_use_ids: list[str],
- tool_results: list[ToolResult],
- cycle_trace: Trace,
- parent_span: Optional[trace_api.Span] = None,
-) -> ToolGenerator:
- """Execute tools concurrently.
-
- Args:
- handler: Tool handler processing function.
- tool_uses: List of tool uses to execute.
- event_loop_metrics: Metrics collection object.
- invalid_tool_use_ids: List of invalid tool use IDs.
- tool_results: List to populate with tool results.
- cycle_trace: Parent trace for the current cycle.
- parent_span: Parent span for the current cycle.
-
- Yields:
- Events of the tool stream. Tool results are appended to `tool_results`.
- """
-
- async def work(
- tool_use: ToolUse,
- worker_id: int,
- worker_queue: asyncio.Queue,
- worker_event: asyncio.Event,
- stop_event: object,
- ) -> ToolResult:
- tracer = get_tracer()
- tool_call_span = tracer.start_tool_call_span(tool_use, parent_span)
-
- tool_name = tool_use["name"]
- tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
- tool_start_time = time.time()
- with trace_api.use_span(tool_call_span):
- try:
- async for event in handler(tool_use):
- worker_queue.put_nowait((worker_id, event))
- await worker_event.wait()
- worker_event.clear()
-
- result = cast(ToolResult, event)
- finally:
- worker_queue.put_nowait((worker_id, stop_event))
-
- tool_success = result.get("status") == "success"
- tool_duration = time.time() - tool_start_time
- message = Message(role="user", content=[{"toolResult": result}])
- event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
- cycle_trace.add_child(tool_trace)
-
- tracer.end_tool_call_span(tool_call_span, result)
-
- return result
-
- tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
- worker_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue()
- worker_events = [asyncio.Event() for _ in tool_uses]
- stop_event = object()
-
- workers = [
- asyncio.create_task(work(tool_use, worker_id, worker_queue, worker_events[worker_id], stop_event))
- for worker_id, tool_use in enumerate(tool_uses)
- ]
-
- worker_count = len(workers)
- while worker_count:
- worker_id, event = await worker_queue.get()
- if event is stop_event:
- worker_count -= 1
- continue
-
- yield event
- worker_events[worker_id].set()
-
- tool_results.extend([worker.result() for worker in workers])
-
-
-def validate_and_prepare_tools(
- message: Message,
- tool_uses: list[ToolUse],
- tool_results: list[ToolResult],
- invalid_tool_use_ids: list[str],
-) -> None:
- """Validate tool uses and prepare them for execution.
-
- Args:
- message: Current message.
- tool_uses: List to populate with tool uses.
- tool_results: List to populate with tool results for invalid tools.
- invalid_tool_use_ids: List to populate with invalid tool use IDs.
- """
- # Extract tool uses from message
- for content in message["content"]:
- if isinstance(content, dict) and "toolUse" in content:
- tool_uses.append(content["toolUse"])
-
- # Validate tool uses
- # Avoid modifying original `tool_uses` variable during iteration
- tool_uses_copy = tool_uses.copy()
- for tool in tool_uses_copy:
- try:
- validate_tool_use(tool)
- except InvalidToolUseNameException as e:
- # Replace the invalid toolUse name and return invalid name error as ToolResult to the LLM as context
- tool_uses.remove(tool)
- tool["name"] = "INVALID_TOOL_NAME"
- invalid_tool_use_ids.append(tool["toolUseId"])
- tool_uses.append(tool)
- tool_results.append(
- {
- "toolUseId": tool["toolUseId"],
- "status": "error",
- "content": [{"text": f"Error: {str(e)}"}],
- }
- )
diff --git a/src/strands/tools/executors/__init__.py b/src/strands/tools/executors/__init__.py
new file mode 100644
index 000000000..c8be812e4
--- /dev/null
+++ b/src/strands/tools/executors/__init__.py
@@ -0,0 +1,16 @@
+"""Tool executors for the Strands SDK.
+
+This package provides different execution strategies for tools, allowing users to customize
+how tools are executed (e.g., concurrent, sequential, with custom thread pools, etc.).
+"""
+
+from . import concurrent, sequential
+from .concurrent import ConcurrentToolExecutor
+from .sequential import SequentialToolExecutor
+
+__all__ = [
+ "ConcurrentToolExecutor",
+ "SequentialToolExecutor",
+ "concurrent",
+ "sequential",
+]
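As the `agent.py` hunk above shows, `ConcurrentToolExecutor` stays the default; this package exists so callers can opt into another strategy explicitly. A usage sketch (agent configuration kept minimal and illustrative):

    from strands import Agent
    from strands.tools.executors import SequentialToolExecutor

    # Run tools within a single model turn one at a time, in request order.
    agent = Agent(tool_executor=SequentialToolExecutor())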
diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py
new file mode 100644
index 000000000..9999b77fc
--- /dev/null
+++ b/src/strands/tools/executors/_executor.py
@@ -0,0 +1,227 @@
+"""Abstract base class for tool executors.
+
+Tool executors are responsible for determining how tools are executed (e.g., concurrently, sequentially, with custom
+thread pools, etc.).
+"""
+
+import abc
+import logging
+import time
+from typing import TYPE_CHECKING, Any, cast
+
+from opentelemetry import trace as trace_api
+
+from ...experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
+from ...telemetry.metrics import Trace
+from ...telemetry.tracer import get_tracer
+from ...types.content import Message
+from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
+
+if TYPE_CHECKING: # pragma: no cover
+ from ...agent import Agent
+
+logger = logging.getLogger(__name__)
+
+
+class ToolExecutor(abc.ABC):
+ """Abstract base class for tool executors."""
+
+ @staticmethod
+ async def _stream(
+ agent: "Agent",
+ tool_use: ToolUse,
+ tool_results: list[ToolResult],
+ invocation_state: dict[str, Any],
+ **kwargs: Any,
+ ) -> ToolGenerator:
+ """Stream tool events.
+
+ This method adds additional logic to the stream invocation including:
+
+ - Tool lookup and validation
+ - Before/after hook execution
+ - Tracing and metrics collection
+ - Error handling and recovery
+
+ Args:
+ agent: The agent for which the tool is being executed.
+ tool_use: Metadata and inputs for the tool to be executed.
+ tool_results: List of tool results from each tool execution.
+ invocation_state: Context for the tool invocation.
+ **kwargs: Additional keyword arguments for future extensibility.
+
+ Yields:
+ Tool events with the last being the tool result.
+ """
+ logger.debug("tool_use=<%s> | streaming", tool_use)
+ tool_name = tool_use["name"]
+
+ tool_info = agent.tool_registry.dynamic_tools.get(tool_name)
+ tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name)
+
+ invocation_state.update(
+ {
+ "model": agent.model,
+ "messages": agent.messages,
+ "system_prompt": agent.system_prompt,
+ "tool_config": ToolConfig( # for backwards compatibility
+ tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()],
+ toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}),
+ ),
+ }
+ )
+
+ before_event = agent.hooks.invoke_callbacks(
+ BeforeToolInvocationEvent(
+ agent=agent,
+ selected_tool=tool_func,
+ tool_use=tool_use,
+ invocation_state=invocation_state,
+ )
+ )
+
+ try:
+ selected_tool = before_event.selected_tool
+ tool_use = before_event.tool_use
+ invocation_state = before_event.invocation_state
+
+ if not selected_tool:
+ if tool_func == selected_tool:
+ logger.error(
+ "tool_name=<%s>, available_tools=<%s> | tool not found in registry",
+ tool_name,
+ list(agent.tool_registry.registry.keys()),
+ )
+ else:
+ logger.debug(
+ "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
+ tool_name,
+ str(tool_use.get("toolUseId")),
+ )
+
+ result: ToolResult = {
+ "toolUseId": str(tool_use.get("toolUseId")),
+ "status": "error",
+ "content": [{"text": f"Unknown tool: {tool_name}"}],
+ }
+ after_event = agent.hooks.invoke_callbacks(
+ AfterToolInvocationEvent(
+ agent=agent,
+ selected_tool=selected_tool,
+ tool_use=tool_use,
+ invocation_state=invocation_state,
+ result=result,
+ )
+ )
+ yield after_event.result
+ tool_results.append(after_event.result)
+ return
+
+ async for event in selected_tool.stream(tool_use, invocation_state, **kwargs):
+ yield event
+
+ result = cast(ToolResult, event)
+
+ after_event = agent.hooks.invoke_callbacks(
+ AfterToolInvocationEvent(
+ agent=agent,
+ selected_tool=selected_tool,
+ tool_use=tool_use,
+ invocation_state=invocation_state,
+ result=result,
+ )
+ )
+ yield after_event.result
+ tool_results.append(after_event.result)
+
+ except Exception as e:
+ logger.exception("tool_name=<%s> | failed to process tool", tool_name)
+ error_result: ToolResult = {
+ "toolUseId": str(tool_use.get("toolUseId")),
+ "status": "error",
+ "content": [{"text": f"Error: {str(e)}"}],
+ }
+ after_event = agent.hooks.invoke_callbacks(
+ AfterToolInvocationEvent(
+ agent=agent,
+ selected_tool=selected_tool,
+ tool_use=tool_use,
+ invocation_state=invocation_state,
+ result=error_result,
+ exception=e,
+ )
+ )
+ yield after_event.result
+ tool_results.append(after_event.result)
+
+ @staticmethod
+ async def _stream_with_trace(
+ agent: "Agent",
+ tool_use: ToolUse,
+ tool_results: list[ToolResult],
+ cycle_trace: Trace,
+ cycle_span: Any,
+ invocation_state: dict[str, Any],
+ **kwargs: Any,
+ ) -> ToolGenerator:
+ """Execute tool with tracing and metrics collection.
+
+ Args:
+ agent: The agent for which the tool is being executed.
+ tool_use: Metadata and inputs for the tool to be executed.
+ tool_results: List of tool results from each tool execution.
+ cycle_trace: Trace object for the current event loop cycle.
+ cycle_span: Span object for tracing the cycle.
+ invocation_state: Context for the tool invocation.
+ **kwargs: Additional keyword arguments for future extensibility.
+
+ Yields:
+ Tool events with the last being the tool result.
+ """
+ tool_name = tool_use["name"]
+
+ tracer = get_tracer()
+
+ tool_call_span = tracer.start_tool_call_span(tool_use, cycle_span)
+ tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
+ tool_start_time = time.time()
+
+ with trace_api.use_span(tool_call_span):
+ async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs):
+ yield event
+
+ result = cast(ToolResult, event)
+
+ tool_success = result.get("status") == "success"
+ tool_duration = time.time() - tool_start_time
+ message = Message(role="user", content=[{"toolResult": result}])
+ agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
+ cycle_trace.add_child(tool_trace)
+
+ tracer.end_tool_call_span(tool_call_span, result)
+
+ @abc.abstractmethod
+ # pragma: no cover
+ def _execute(
+ self,
+ agent: "Agent",
+ tool_uses: list[ToolUse],
+ tool_results: list[ToolResult],
+ cycle_trace: Trace,
+ cycle_span: Any,
+ invocation_state: dict[str, Any],
+ ) -> ToolGenerator:
+ """Execute the given tools according to this executor's strategy.
+
+ Args:
+ agent: The agent for which tools are being executed.
+ tool_uses: Metadata and inputs for the tools to be executed.
+ tool_results: List of tool results from each tool execution.
+ cycle_trace: Trace object for the current event loop cycle.
+ cycle_span: Span object for tracing the cycle.
+ invocation_state: Context for the tool invocation.
+
+ Yields:
+ Events from the tool execution stream.
+ """
+ pass
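Because `_stream_with_trace` carries all of the hook, tracing, and metrics plumbing, a custom strategy only has to decide scheduling. A hypothetical subclass (not part of this patch) that logs each tool before running it sequentially:

    import logging

    logger = logging.getLogger(__name__)

    class LoggingToolExecutor(ToolExecutor):
        """Hypothetical executor: sequential execution with per-tool logging."""

        async def _execute(self, agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state):
            for tool_use in tool_uses:
                logger.info("executing tool_name=<%s>", tool_use["name"])
                events = ToolExecutor._stream_with_trace(
                    agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state
                )
                async for event in events:
                    yield event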
diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py
new file mode 100644
index 000000000..7d5dd7fe7
--- /dev/null
+++ b/src/strands/tools/executors/concurrent.py
@@ -0,0 +1,113 @@
+"""Concurrent tool executor implementation."""
+
+import asyncio
+from typing import TYPE_CHECKING, Any
+
+from typing_extensions import override
+
+from ...telemetry.metrics import Trace
+from ...types.tools import ToolGenerator, ToolResult, ToolUse
+from ._executor import ToolExecutor
+
+if TYPE_CHECKING: # pragma: no cover
+ from ...agent import Agent
+
+
+class ConcurrentToolExecutor(ToolExecutor):
+ """Concurrent tool executor."""
+
+ @override
+ async def _execute(
+ self,
+ agent: "Agent",
+ tool_uses: list[ToolUse],
+ tool_results: list[ToolResult],
+ cycle_trace: Trace,
+ cycle_span: Any,
+ invocation_state: dict[str, Any],
+ ) -> ToolGenerator:
+ """Execute tools concurrently.
+
+ Args:
+ agent: The agent for which tools are being executed.
+ tool_uses: Metadata and inputs for the tools to be executed.
+ tool_results: List of tool results from each tool execution.
+ cycle_trace: Trace object for the current event loop cycle.
+ cycle_span: Span object for tracing the cycle.
+ invocation_state: Context for the tool invocation.
+
+ Yields:
+ Events from the tool execution stream.
+ """
+ task_queue: asyncio.Queue[tuple[int, Any]] = asyncio.Queue()
+ task_events = [asyncio.Event() for _ in tool_uses]
+ stop_event = object()
+
+ tasks = [
+ asyncio.create_task(
+ self._task(
+ agent,
+ tool_use,
+ tool_results,
+ cycle_trace,
+ cycle_span,
+ invocation_state,
+ task_id,
+ task_queue,
+ task_events[task_id],
+ stop_event,
+ )
+ )
+ for task_id, tool_use in enumerate(tool_uses)
+ ]
+
+ task_count = len(tasks)
+ while task_count:
+ task_id, event = await task_queue.get()
+ if event is stop_event:
+ task_count -= 1
+ continue
+
+ yield event
+ task_events[task_id].set()
+
+        await asyncio.gather(*tasks)
+
+ async def _task(
+ self,
+ agent: "Agent",
+ tool_use: ToolUse,
+ tool_results: list[ToolResult],
+ cycle_trace: Trace,
+ cycle_span: Any,
+ invocation_state: dict[str, Any],
+ task_id: int,
+ task_queue: asyncio.Queue,
+ task_event: asyncio.Event,
+ stop_event: object,
+ ) -> None:
+ """Execute a single tool and put results in the task queue.
+
+ Args:
+ agent: The agent executing the tool.
+ tool_use: Tool use metadata and inputs.
+ tool_results: List of tool results from each tool execution.
+ cycle_trace: Trace object for the current event loop cycle.
+ cycle_span: Span object for tracing the cycle.
+ invocation_state: Context for tool execution.
+ task_id: Unique identifier for this task.
+ task_queue: Queue to put tool events into.
+ task_event: Event to signal when task can continue.
+ stop_event: Sentinel object to signal task completion.
+ """
+ try:
+ events = ToolExecutor._stream_with_trace(
+ agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state
+ )
+ async for event in events:
+ task_queue.put_nowait((task_id, event))
+ await task_event.wait()
+ task_event.clear()
+
+ finally:
+ task_queue.put_nowait((task_id, stop_event))
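The queue-and-event handshake above is what keeps the merged stream orderly: each task parks after publishing an event until the consumer in `_execute` has yielded it and set that task's event, and the shared `stop_event` sentinel tells the consumer a task has drained. The same pattern in isolation (all names are local to this sketch):

    import asyncio

    async def producer(queue: asyncio.Queue, resume: asyncio.Event, stop: object) -> None:
        try:
            for item in ("a", "b"):
                queue.put_nowait(item)
                await resume.wait()  # park until the consumer has handled the item
                resume.clear()
        finally:
            queue.put_nowait(stop)  # sentinel: this producer is finished

    async def main() -> None:
        queue, resume, stop = asyncio.Queue(), asyncio.Event(), object()
        task = asyncio.create_task(producer(queue, resume, stop))
        while True:
            item = await queue.get()
            if item is stop:
                break
            print(item)
            resume.set()  # hand control back to the producer
        await task

    asyncio.run(main())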
diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py
new file mode 100644
index 000000000..55b26f6d3
--- /dev/null
+++ b/src/strands/tools/executors/sequential.py
@@ -0,0 +1,46 @@
+"""Sequential tool executor implementation."""
+
+from typing import TYPE_CHECKING, Any
+
+from typing_extensions import override
+
+from ...telemetry.metrics import Trace
+from ...types.tools import ToolGenerator, ToolResult, ToolUse
+from ._executor import ToolExecutor
+
+if TYPE_CHECKING: # pragma: no cover
+ from ...agent import Agent
+
+
+class SequentialToolExecutor(ToolExecutor):
+ """Sequential tool executor."""
+
+ @override
+ async def _execute(
+ self,
+ agent: "Agent",
+ tool_uses: list[ToolUse],
+ tool_results: list[ToolResult],
+ cycle_trace: Trace,
+ cycle_span: Any,
+ invocation_state: dict[str, Any],
+ ) -> ToolGenerator:
+ """Execute tools sequentially.
+
+ Args:
+ agent: The agent for which tools are being executed.
+ tool_uses: Metadata and inputs for the tools to be executed.
+ tool_results: List of tool results from each tool execution.
+ cycle_trace: Trace object for the current event loop cycle.
+ cycle_span: Span object for tracing the cycle.
+ invocation_state: Context for the tool invocation.
+
+ Yields:
+ Events from the tool execution stream.
+ """
+ for tool_use in tool_uses:
+ events = ToolExecutor._stream_with_trace(
+ agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state
+ )
+ async for event in events:
+ yield event
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 7e769c6d7..279e2a06e 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -73,12 +73,6 @@ def mock_event_loop_cycle():
yield mock
-@pytest.fixture
-def mock_run_tool():
- with unittest.mock.patch("strands.agent.agent.run_tool") as mock:
- yield mock
-
-
@pytest.fixture
def tool_registry():
return strands.tools.registry.ToolRegistry()
@@ -888,9 +882,7 @@ def test_agent_init_with_no_model_or_model_id():
assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID
-def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, mock_run_tool, agenerator):
- mock_run_tool.return_value = agenerator([{}])
-
+def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator):
@strands.tools.tool(name="system_prompter")
def function(system_prompt: str) -> str:
return system_prompt
@@ -899,22 +891,12 @@ def function(system_prompt: str) -> str:
mock_randint.return_value = 1
- agent.tool.system_prompter(system_prompt="tool prompt")
-
- mock_run_tool.assert_called_with(
- agent,
- {
- "toolUseId": "tooluse_system_prompter_1",
- "name": "system_prompter",
- "input": {"system_prompt": "tool prompt"},
- },
- {"system_prompt": "tool prompt"},
- )
-
+ tru_result = agent.tool.system_prompter(system_prompt="tool prompt")
+ exp_result = {"toolUseId": "tooluse_system_prompter_1", "status": "success", "content": [{"text": "tool prompt"}]}
+ assert tru_result == exp_result
-def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, mock_run_tool, agenerator):
- mock_run_tool.return_value = agenerator([{}])
+def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint, agenerator):
tool_name = "system-prompter"
@strands.tools.tool(name=tool_name)
@@ -925,19 +907,9 @@ def function(system_prompt: str) -> str:
mock_randint.return_value = 1
- agent.tool.system_prompter(system_prompt="tool prompt")
-
- # Verify the correct tool was invoked
- assert mock_run_tool.call_count == 1
- tru_tool_use = mock_run_tool.call_args.args[1]
- exp_tool_use = {
- # Note that the tool-use uses the "python safe" name
- "toolUseId": "tooluse_system_prompter_1",
- # But the name of the tool is the one in the registry
- "name": tool_name,
- "input": {"system_prompt": "tool prompt"},
- }
- assert tru_tool_use == exp_tool_use
+ tru_result = agent.tool.system_prompter(system_prompt="tool prompt")
+ exp_result = {"toolUseId": "tooluse_system_prompter_1", "status": "success", "content": [{"text": "tool prompt"}]}
+ assert tru_result == exp_result
def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint):
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py
index 191ab51ba..c76514ac8 100644
--- a/tests/strands/event_loop/test_event_loop.py
+++ b/tests/strands/event_loop/test_event_loop.py
@@ -1,23 +1,20 @@
import concurrent
import unittest.mock
-from unittest.mock import ANY, MagicMock, call, patch
+from unittest.mock import MagicMock, call, patch
import pytest
import strands
import strands.telemetry
-from strands.event_loop.event_loop import run_tool
from strands.experimental.hooks import (
AfterModelInvocationEvent,
AfterToolInvocationEvent,
BeforeModelInvocationEvent,
BeforeToolInvocationEvent,
)
-from strands.hooks import (
- HookProvider,
- HookRegistry,
-)
+from strands.hooks import HookRegistry
from strands.telemetry.metrics import EventLoopMetrics
+from strands.tools.executors import SequentialToolExecutor
from strands.tools.registry import ToolRegistry
from strands.types.exceptions import (
ContextWindowOverflowException,
@@ -131,7 +128,12 @@ def hook_provider(hook_registry):
@pytest.fixture
-def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry):
+def tool_executor():
+ return SequentialToolExecutor()
+
+
+@pytest.fixture
+def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry, tool_executor):
mock = unittest.mock.Mock(name="agent")
mock.config.cache_points = []
mock.model = model
@@ -141,6 +143,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis
mock.thread_pool = thread_pool
mock.event_loop_metrics = EventLoopMetrics()
mock.hooks = hook_registry
+ mock.tool_executor = tool_executor
return mock
@@ -812,260 +815,6 @@ async def test_prepare_next_cycle_in_tool_execution(agent, model, tool_stream, a
)
-@pytest.mark.asyncio
-async def test_run_tool(agent, tool, alist):
- process = run_tool(
- agent,
- tool_use={"toolUseId": "tool_use_id", "name": tool.tool_name, "input": {"random_string": "a_string"}},
- invocation_state={},
- )
-
- tru_result = (await alist(process))[-1]
- exp_result = {"toolUseId": "tool_use_id", "status": "success", "content": [{"text": "a_string"}]}
-
- assert tru_result == exp_result
-
-
-@pytest.mark.asyncio
-async def test_run_tool_missing_tool(agent, alist):
- process = run_tool(
- agent,
- tool_use={"toolUseId": "missing", "name": "missing", "input": {}},
- invocation_state={},
- )
-
- tru_events = await alist(process)
- exp_events = [
- {
- "toolUseId": "missing",
- "status": "error",
- "content": [{"text": "Unknown tool: missing"}],
- },
- ]
-
- assert tru_events == exp_events
-
-
-@pytest.mark.asyncio
-async def test_run_tool_hooks(agent, hook_provider, tool_times_2, alist):
- """Test that the correct hooks are emitted."""
-
- process = run_tool(
- agent=agent,
- tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}},
- invocation_state={},
- )
- await alist(process)
-
- assert len(hook_provider.events_received) == 2
-
- assert hook_provider.events_received[0] == BeforeToolInvocationEvent(
- agent=agent,
- selected_tool=tool_times_2,
- tool_use={"input": {"x": 5}, "name": "multiply_by_2", "toolUseId": "test"},
- invocation_state=ANY,
- )
-
- assert hook_provider.events_received[1] == AfterToolInvocationEvent(
- agent=agent,
- selected_tool=tool_times_2,
- exception=None,
- tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}},
- result={"toolUseId": "test", "status": "success", "content": [{"text": "10"}]},
- invocation_state=ANY,
- )
-
-
-@pytest.mark.asyncio
-async def test_run_tool_hooks_on_missing_tool(agent, hook_provider, alist):
- """Test that AfterToolInvocation hook is invoked even when tool throws exception."""
- process = run_tool(
- agent=agent,
- tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}},
- invocation_state={},
- )
- await alist(process)
-
- assert len(hook_provider.events_received) == 2
-
- assert hook_provider.events_received[0] == BeforeToolInvocationEvent(
- agent=agent,
- selected_tool=None,
- tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"},
- invocation_state=ANY,
- )
-
- assert hook_provider.events_received[1] == AfterToolInvocationEvent(
- agent=agent,
- selected_tool=None,
- tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"},
- invocation_state=ANY,
- result={"content": [{"text": "Unknown tool: missing_tool"}], "status": "error", "toolUseId": "test"},
- exception=None,
- )
-
-
-@pytest.mark.asyncio
-async def test_run_tool_hook_after_tool_invocation_on_exception(agent, tool_registry, hook_provider, alist):
- """Test that AfterToolInvocation hook is invoked even when tool throws exception."""
- error = ValueError("Tool failed")
-
- failing_tool = MagicMock()
- failing_tool.tool_name = "failing_tool"
-
- failing_tool.stream.side_effect = error
-
- tool_registry.register_tool(failing_tool)
-
- process = run_tool(
- agent=agent,
- tool_use={"toolUseId": "test", "name": "failing_tool", "input": {"x": 5}},
- invocation_state={},
- )
- await alist(process)
-
- assert hook_provider.events_received[1] == AfterToolInvocationEvent(
- agent=agent,
- selected_tool=failing_tool,
- tool_use={"input": {"x": 5}, "name": "failing_tool", "toolUseId": "test"},
- invocation_state=ANY,
- result={"content": [{"text": "Error: Tool failed"}], "status": "error", "toolUseId": "test"},
- exception=error,
- )
-
-
-@pytest.mark.asyncio
-async def test_run_tool_hook_before_tool_invocation_updates(agent, tool_times_5, hook_registry, hook_provider, alist):
- """Test that modifying properties on BeforeToolInvocation takes effect."""
-
- updated_tool_use = {"toolUseId": "modified", "name": "replacement_tool", "input": {"x": 3}}
-
- def modify_hook(event: BeforeToolInvocationEvent):
- # Modify selected_tool to use replacement_tool
- event.selected_tool = tool_times_5
- # Modify tool_use to change toolUseId
- event.tool_use = updated_tool_use
-
- hook_registry.add_callback(BeforeToolInvocationEvent, modify_hook)
-
- process = run_tool(
- agent=agent,
- tool_use={"toolUseId": "original", "name": "original_tool", "input": {"x": 1}},
- invocation_state={},
- )
- result = (await alist(process))[-1]
-
- # Should use replacement_tool (5 * 3 = 15) instead of original_tool (1 * 2 = 2)
- assert result == {"toolUseId": "modified", "status": "success", "content": [{"text": "15"}]}
-
- assert hook_provider.events_received[1] == AfterToolInvocationEvent(
- agent=agent,
- selected_tool=tool_times_5,
- tool_use=updated_tool_use,
- invocation_state=ANY,
- result={"content": [{"text": "15"}], "status": "success", "toolUseId": "modified"},
- exception=None,
- )
-
-
-@pytest.mark.asyncio
-async def test_run_tool_hook_after_tool_invocation_updates(agent, tool_times_2, hook_registry, alist):
- """Test that modifying properties on AfterToolInvocation takes effect."""
-
- updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]}
-
- def modify_hook(event: AfterToolInvocationEvent):
- # Modify result to change the output
- event.result = updated_result
-
- hook_registry.add_callback(AfterToolInvocationEvent, modify_hook)
-
- process = run_tool(
- agent=agent,
- tool_use={"toolUseId": "test", "name": tool_times_2.tool_name, "input": {"x": 5}},
- invocation_state={},
- )
-
- result = (await alist(process))[-1]
- assert result == updated_result
-
-
-@pytest.mark.asyncio
-async def test_run_tool_hook_after_tool_invocation_updates_with_missing_tool(agent, hook_registry, alist):
- """Test that modifying properties on AfterToolInvocation takes effect."""
-
- updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]}
-
- def modify_hook(event: AfterToolInvocationEvent):
- # Modify result to change the output
- event.result = updated_result
-
- hook_registry.add_callback(AfterToolInvocationEvent, modify_hook)
-
- process = run_tool(
- agent=agent,
- tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}},
- invocation_state={},
- )
-
- result = (await alist(process))[-1]
- assert result == updated_result
-
-
-@pytest.mark.asyncio
-async def test_run_tool_hook_update_result_with_missing_tool(agent, tool_registry, hook_registry, alist):
- """Test that modifying properties on AfterToolInvocation takes effect."""
-
- @strands.tool
- def test_quota():
- return "9"
-
- tool_registry.register_tool(test_quota)
-
- class ExampleProvider(HookProvider):
- def register_hooks(self, registry: "HookRegistry") -> None:
- registry.add_callback(BeforeToolInvocationEvent, self.before_tool_call)
- registry.add_callback(AfterToolInvocationEvent, self.after_tool_call)
-
- def before_tool_call(self, event: BeforeToolInvocationEvent):
- if event.tool_use.get("name") == "test_quota":
- event.selected_tool = None
-
- def after_tool_call(self, event: AfterToolInvocationEvent):
- if event.tool_use.get("name") == "test_quota":
- event.result = {
- "status": "error",
- "toolUseId": "test",
- "content": [{"text": "This tool has been used too many times!"}],
- }
-
- hook_registry.add_hook(ExampleProvider())
-
- with patch.object(strands.event_loop.event_loop, "logger") as mock_logger:
- process = run_tool(
- agent=agent,
- tool_use={"toolUseId": "test", "name": "test_quota", "input": {"x": 5}},
- invocation_state={},
- )
-
- result = (await alist(process))[-1]
-
- assert result == {
- "status": "error",
- "toolUseId": "test",
- "content": [{"text": "This tool has been used too many times!"}],
- }
-
- assert mock_logger.debug.call_args_list == [
- call("tool_use=<%s> | streaming", {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}),
- call(
- "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
- "test_quota",
- "test",
- ),
- ]
-
-
@pytest.mark.asyncio
async def test_event_loop_cycle_exception_model_hooks(mock_time, agent, model, agenerator, alist, hook_provider):
"""Test that model hooks are correctly emitted even when throttled."""
diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py
new file mode 100644
index 000000000..1576b7578
--- /dev/null
+++ b/tests/strands/tools/executors/conftest.py
@@ -0,0 +1,116 @@
+import threading
+import unittest.mock
+
+import pytest
+
+import strands
+from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
+from strands.hooks import HookRegistry
+from strands.tools.registry import ToolRegistry
+
+
+@pytest.fixture
+def hook_events():
+ return []
+
+
+@pytest.fixture
+def tool_hook(hook_events):
+ def callback(event):
+ hook_events.append(event)
+ return event
+
+ return callback
+
+
+@pytest.fixture
+def hook_registry(tool_hook):
+ registry = HookRegistry()
+ registry.add_callback(BeforeToolInvocationEvent, tool_hook)
+ registry.add_callback(AfterToolInvocationEvent, tool_hook)
+ return registry
+
+
+@pytest.fixture
+def tool_events():
+ return []
+
+
+@pytest.fixture
+def weather_tool():
+ @strands.tool(name="weather_tool")
+ def func():
+ return "sunny"
+
+ return func
+
+
+@pytest.fixture
+def temperature_tool():
+ @strands.tool(name="temperature_tool")
+ def func():
+ return "75F"
+
+ return func
+
+
+@pytest.fixture
+def exception_tool():
+ @strands.tool(name="exception_tool")
+ def func():
+ pass
+
+ async def mock_stream(_tool_use, _invocation_state):
+ raise RuntimeError("Tool error")
+ yield # make generator
+
+ func.stream = mock_stream
+ return func
+
+
+@pytest.fixture
+def thread_tool(tool_events):
+ @strands.tool(name="thread_tool")
+ def func():
+ tool_events.append({"thread_name": threading.current_thread().name})
+ return "threaded"
+
+ return func
+
+
+@pytest.fixture
+def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool):
+ registry = ToolRegistry()
+ registry.register_tool(weather_tool)
+ registry.register_tool(temperature_tool)
+ registry.register_tool(exception_tool)
+ registry.register_tool(thread_tool)
+ return registry
+
+
+@pytest.fixture
+def agent(tool_registry, hook_registry):
+ mock_agent = unittest.mock.Mock()
+ mock_agent.tool_registry = tool_registry
+ mock_agent.hooks = hook_registry
+ return mock_agent
+
+
+@pytest.fixture
+def tool_results():
+ return []
+
+
+@pytest.fixture
+def cycle_trace():
+ return unittest.mock.Mock()
+
+
+@pytest.fixture
+def cycle_span():
+ return unittest.mock.Mock()
+
+
+@pytest.fixture
+def invocation_state():
+ return {}
diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py
new file mode 100644
index 000000000..7e0d6c2df
--- /dev/null
+++ b/tests/strands/tools/executors/test_concurrent.py
@@ -0,0 +1,32 @@
+import pytest
+
+from strands.tools.executors import ConcurrentToolExecutor
+
+
+@pytest.fixture
+def executor():
+ return ConcurrentToolExecutor()
+
+
+@pytest.mark.asyncio
+async def test_concurrent_executor_execute(
+ executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist
+):
+ tool_uses = [
+ {"name": "weather_tool", "toolUseId": "1", "input": {}},
+ {"name": "temperature_tool", "toolUseId": "2", "input": {}},
+ ]
+ stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state)
+
+ tru_events = sorted(await alist(stream), key=lambda event: event.get("toolUseId"))
+ exp_events = [
+ {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
+ {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
+ {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]},
+ {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]},
+ ]
+ assert tru_events == exp_events
+
+ tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId"))
+ exp_results = [exp_events[1], exp_events[3]]
+ assert tru_results == exp_results
diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py
new file mode 100644
index 000000000..edbad3939
--- /dev/null
+++ b/tests/strands/tools/executors/test_executor.py
@@ -0,0 +1,144 @@
+import unittest.mock
+
+import pytest
+
+import strands
+from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
+from strands.telemetry.metrics import Trace
+from strands.tools.executors._executor import ToolExecutor
+
+
+@pytest.fixture
+def executor_cls():
+ class ClsExecutor(ToolExecutor):
+ def _execute(self, _agent, _tool_uses, _tool_results, _invocation_state):
+ raise NotImplementedError
+
+ return ClsExecutor
+
+
+@pytest.fixture
+def executor(executor_cls):
+ return executor_cls()
+
+
+@pytest.fixture
+def tracer():
+ with unittest.mock.patch.object(strands.tools.executors._executor, "get_tracer") as mock_get_tracer:
+ yield mock_get_tracer.return_value
+
+
+@pytest.mark.asyncio
+async def test_executor_stream_yields_result(
+ executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist
+):
+ tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}}
+ stream = executor._stream(agent, tool_use, tool_results, invocation_state)
+
+ tru_events = await alist(stream)
+ exp_events = [
+ {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
+ {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
+ ]
+ assert tru_events == exp_events
+
+ tru_results = tool_results
+ exp_results = [exp_events[-1]]
+ assert tru_results == exp_results
+
+ tru_hook_events = hook_events
+ exp_hook_events = [
+ BeforeToolInvocationEvent(
+ agent=agent,
+ selected_tool=weather_tool,
+ tool_use=tool_use,
+ invocation_state=invocation_state,
+ ),
+ AfterToolInvocationEvent(
+ agent=agent,
+ selected_tool=weather_tool,
+ tool_use=tool_use,
+ invocation_state=invocation_state,
+ result=exp_results[0],
+ ),
+ ]
+ assert tru_hook_events == exp_hook_events
+
+
+@pytest.mark.asyncio
+async def test_executor_stream_yields_tool_error(
+ executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist
+):
+ tool_use = {"name": "exception_tool", "toolUseId": "1", "input": {}}
+ stream = executor._stream(agent, tool_use, tool_results, invocation_state)
+
+ tru_events = await alist(stream)
+ exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Error: Tool error"}]}]
+ assert tru_events == exp_events
+
+ tru_results = tool_results
+ exp_results = [exp_events[-1]]
+ assert tru_results == exp_results
+
+ tru_hook_after_event = hook_events[-1]
+ exp_hook_after_event = AfterToolInvocationEvent(
+ agent=agent,
+ selected_tool=exception_tool,
+ tool_use=tool_use,
+ invocation_state=invocation_state,
+ result=exp_results[0],
+ exception=unittest.mock.ANY,
+ )
+ assert tru_hook_after_event == exp_hook_after_event
+
+
+@pytest.mark.asyncio
+async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results, invocation_state, hook_events, alist):
+ tool_use = {"name": "unknown_tool", "toolUseId": "1", "input": {}}
+ stream = executor._stream(agent, tool_use, tool_results, invocation_state)
+
+ tru_events = await alist(stream)
+ exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Unknown tool: unknown_tool"}]}]
+ assert tru_events == exp_events
+
+ tru_results = tool_results
+ exp_results = [exp_events[-1]]
+ assert tru_results == exp_results
+
+ tru_hook_after_event = hook_events[-1]
+ exp_hook_after_event = AfterToolInvocationEvent(
+ agent=agent,
+ selected_tool=None,
+ tool_use=tool_use,
+ invocation_state=invocation_state,
+ result=exp_results[0],
+ )
+ assert tru_hook_after_event == exp_hook_after_event
+
+
+@pytest.mark.asyncio
+async def test_executor_stream_with_trace(
+ executor, tracer, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist
+):
+ tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}}
+ stream = executor._stream_with_trace(agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state)
+
+ tru_events = await alist(stream)
+ exp_events = [
+ {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
+ {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
+ ]
+ assert tru_events == exp_events
+
+ tru_results = tool_results
+ exp_results = [exp_events[-1]]
+ assert tru_results == exp_results
+
+ tracer.start_tool_call_span.assert_called_once_with(tool_use, cycle_span)
+ tracer.end_tool_call_span.assert_called_once_with(
+ tracer.start_tool_call_span.return_value,
+ {"content": [{"text": "sunny"}], "status": "success", "toolUseId": "1"},
+ )
+
+ cycle_trace.add_child.assert_called_once()
+ assert isinstance(cycle_trace.add_child.call_args[0][0], Trace)
diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py
new file mode 100644
index 000000000..d9b32c129
--- /dev/null
+++ b/tests/strands/tools/executors/test_sequential.py
@@ -0,0 +1,32 @@
+import pytest
+
+from strands.tools.executors import SequentialToolExecutor
+
+
+@pytest.fixture
+def executor():
+ return SequentialToolExecutor()
+
+
+@pytest.mark.asyncio
+async def test_sequential_executor_execute(
+ executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist
+):
+ tool_uses = [
+ {"name": "weather_tool", "toolUseId": "1", "input": {}},
+ {"name": "temperature_tool", "toolUseId": "2", "input": {}},
+ ]
+ stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state)
+
+ tru_events = await alist(stream)
+ exp_events = [
+ {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
+ {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
+ {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]},
+ {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]},
+ ]
+ assert tru_events == exp_events
+
+ tru_results = tool_results
+ exp_results = [exp_events[1], exp_events[2]]
+ assert tru_results == exp_results
diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py
deleted file mode 100644
index 04d4ea657..000000000
--- a/tests/strands/tools/test_executor.py
+++ /dev/null
@@ -1,440 +0,0 @@
-import unittest.mock
-import uuid
-
-import pytest
-
-import strands
-import strands.telemetry
-from strands.types.content import Message
-
-
-@pytest.fixture(autouse=True)
-def moto_autouse(moto_env):
- _ = moto_env
-
-
-@pytest.fixture
-def tool_handler(request):
- async def handler(tool_use):
- yield {"event": "abc"}
- yield {
- **params,
- "toolUseId": tool_use["toolUseId"],
- }
-
- params = {
- "content": [{"text": "test result"}],
- "status": "success",
- }
- if hasattr(request, "param"):
- params.update(request.param)
-
- return handler
-
-
-@pytest.fixture
-def tool_use():
- return {"toolUseId": "t1", "name": "test_tool", "input": {"key": "value"}}
-
-
-@pytest.fixture
-def tool_uses(request, tool_use):
- return request.param if hasattr(request, "param") else [tool_use]
-
-
-@pytest.fixture
-def mock_metrics_client():
- with unittest.mock.patch("strands.telemetry.MetricsClient") as mock_metrics_client:
- yield mock_metrics_client
-
-
-@pytest.fixture
-def event_loop_metrics():
- return strands.telemetry.metrics.EventLoopMetrics()
-
-
-@pytest.fixture
-def invalid_tool_use_ids(request):
- return request.param if hasattr(request, "param") else []
-
-
-@pytest.fixture
-def cycle_trace():
- with unittest.mock.patch.object(uuid, "uuid4", return_value="trace1"):
- return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name")
-
-
-@pytest.mark.asyncio
-async def test_run_tools(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- cycle_trace,
- alist,
-):
- tool_results = []
-
- stream = strands.tools.executor.run_tools(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- tool_results,
- cycle_trace,
- )
-
- tru_events = await alist(stream)
- exp_events = [
- {"event": "abc"},
- {
- "content": [
- {
- "text": "test result",
- },
- ],
- "status": "success",
- "toolUseId": "t1",
- },
- ]
-
- tru_results = tool_results
- exp_results = [exp_events[-1]]
-
- assert tru_events == exp_events and tru_results == exp_results
-
-
-@pytest.mark.parametrize("invalid_tool_use_ids", [["t1"]], indirect=True)
-@pytest.mark.asyncio
-async def test_run_tools_invalid_tool(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- cycle_trace,
- alist,
-):
- tool_results = []
-
- stream = strands.tools.executor.run_tools(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- tool_results,
- cycle_trace,
- )
- await alist(stream)
-
- tru_results = tool_results
- exp_results = []
-
- assert tru_results == exp_results
-
-
-@pytest.mark.parametrize("tool_handler", [{"status": "failed"}], indirect=True)
-@pytest.mark.asyncio
-async def test_run_tools_failed_tool(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- cycle_trace,
- alist,
-):
- tool_results = []
-
- stream = strands.tools.executor.run_tools(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- tool_results,
- cycle_trace,
- )
- await alist(stream)
-
- tru_results = tool_results
- exp_results = [
- {
- "content": [
- {
- "text": "test result",
- },
- ],
- "status": "failed",
- "toolUseId": "t1",
- },
- ]
-
- assert tru_results == exp_results
-
-
-@pytest.mark.parametrize(
- ("tool_uses", "invalid_tool_use_ids"),
- [
- (
- [
- {
- "toolUseId": "t1",
- "name": "test_tool_success",
- "input": {"key": "value1"},
- },
- {
- "toolUseId": "t2",
- "name": "test_tool_invalid",
- "input": {"key": "value2"},
- },
- ],
- ["t2"],
- ),
- ],
- indirect=True,
-)
-@pytest.mark.asyncio
-async def test_run_tools_sequential(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- cycle_trace,
- alist,
-):
- tool_results = []
-
- stream = strands.tools.executor.run_tools(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- tool_results,
- cycle_trace,
- None, # tool_pool
- )
- await alist(stream)
-
- tru_results = tool_results
- exp_results = [
- {
- "content": [
- {
- "text": "test result",
- },
- ],
- "status": "success",
- "toolUseId": "t1",
- },
- ]
-
- assert tru_results == exp_results
-
-
-def test_validate_and_prepare_tools():
- message: Message = {
- "role": "assistant",
- "content": [
- {"text": "value"},
- {"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {"key": "value"}}},
- {"toolUse": {"toolUseId": "t2-invalid"}},
- ],
- }
-
- tool_uses = []
- tool_results = []
- invalid_tool_use_ids = []
-
- strands.tools.executor.validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
-
- tru_tool_uses, tru_tool_results, tru_invalid_tool_use_ids = tool_uses, tool_results, invalid_tool_use_ids
- exp_tool_uses = [
- {
- "input": {
- "key": "value",
- },
- "name": "test_tool",
- "toolUseId": "t1",
- },
- {
- "name": "INVALID_TOOL_NAME",
- "toolUseId": "t2-invalid",
- },
- ]
- exp_tool_results = [
- {
- "content": [
- {
- "text": "Error: tool name missing",
- },
- ],
- "status": "error",
- "toolUseId": "t2-invalid",
- },
- ]
- exp_invalid_tool_use_ids = ["t2-invalid"]
-
- assert tru_tool_uses == exp_tool_uses
- assert tru_tool_results == exp_tool_results
- assert tru_invalid_tool_use_ids == exp_invalid_tool_use_ids
-
-
-@unittest.mock.patch("strands.tools.executor.get_tracer")
-@pytest.mark.asyncio
-async def test_run_tools_creates_and_ends_span_on_success(
- mock_get_tracer,
- tool_handler,
- tool_uses,
- mock_metrics_client,
- event_loop_metrics,
- invalid_tool_use_ids,
- cycle_trace,
- alist,
-):
- """Test that run_tools creates and ends a span on successful execution."""
- # Setup mock tracer and span
- mock_tracer = unittest.mock.MagicMock()
- mock_span = unittest.mock.MagicMock()
- mock_tracer.start_tool_call_span.return_value = mock_span
- mock_get_tracer.return_value = mock_tracer
-
- # Setup mock parent span
- parent_span = unittest.mock.MagicMock()
-
- tool_results = []
-
- # Run the tool
- stream = strands.tools.executor.run_tools(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- tool_results,
- cycle_trace,
- parent_span,
- )
- await alist(stream)
-
- # Verify span was created with the parent span
- mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span)
-
- # Verify span was ended with the tool result
- mock_tracer.end_tool_call_span.assert_called_once()
- args, _ = mock_tracer.end_tool_call_span.call_args
- assert args[0] == mock_span
- assert args[1]["status"] == "success"
- assert args[1]["content"][0]["text"] == "test result"
-
-
-@unittest.mock.patch("strands.tools.executor.get_tracer")
-@pytest.mark.parametrize("tool_handler", [{"status": "failed"}], indirect=True)
-@pytest.mark.asyncio
-async def test_run_tools_creates_and_ends_span_on_failure(
- mock_get_tracer,
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- cycle_trace,
- alist,
-):
- """Test that run_tools creates and ends a span on tool failure."""
- # Setup mock tracer and span
- mock_tracer = unittest.mock.MagicMock()
- mock_span = unittest.mock.MagicMock()
- mock_tracer.start_tool_call_span.return_value = mock_span
- mock_get_tracer.return_value = mock_tracer
-
- # Setup mock parent span
- parent_span = unittest.mock.MagicMock()
-
- tool_results = []
-
- # Run the tool
- stream = strands.tools.executor.run_tools(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- tool_results,
- cycle_trace,
- parent_span,
- )
- await alist(stream)
-
- # Verify span was created with the parent span
- mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span)
-
- # Verify span was ended with the tool result
- mock_tracer.end_tool_call_span.assert_called_once()
- args, _ = mock_tracer.end_tool_call_span.call_args
- assert args[0] == mock_span
- assert args[1]["status"] == "failed"
-
-
-@unittest.mock.patch("strands.tools.executor.get_tracer")
-@pytest.mark.parametrize(
- ("tool_uses", "invalid_tool_use_ids"),
- [
- (
- [
- {
- "toolUseId": "t1",
- "name": "test_tool_success",
- "input": {"key": "value1"},
- },
- {
- "toolUseId": "t2",
- "name": "test_tool_also_success",
- "input": {"key": "value2"},
- },
- ],
- [],
- ),
- ],
- indirect=True,
-)
-@pytest.mark.asyncio
-async def test_run_tools_concurrent_execution_with_spans(
- mock_get_tracer,
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- cycle_trace,
- alist,
-):
- # Setup mock tracer and spans
- mock_tracer = unittest.mock.MagicMock()
- mock_span1 = unittest.mock.MagicMock()
- mock_span2 = unittest.mock.MagicMock()
- mock_tracer.start_tool_call_span.side_effect = [mock_span1, mock_span2]
- mock_get_tracer.return_value = mock_tracer
-
- # Setup mock parent span
- parent_span = unittest.mock.MagicMock()
-
- tool_results = []
-
- # Run the tools
- stream = strands.tools.executor.run_tools(
- tool_handler,
- tool_uses,
- event_loop_metrics,
- invalid_tool_use_ids,
- tool_results,
- cycle_trace,
- parent_span,
- )
- await alist(stream)
-
- # Verify spans were created for both tools
- assert mock_tracer.start_tool_call_span.call_count == 2
- mock_tracer.start_tool_call_span.assert_has_calls(
- [
- unittest.mock.call(tool_uses[0], parent_span),
- unittest.mock.call(tool_uses[1], parent_span),
- ],
- any_order=True,
- )
-
- # Verify spans were ended for both tools
- assert mock_tracer.end_tool_call_span.call_count == 2
diff --git a/tests/strands/tools/test_validator.py b/tests/strands/tools/test_validator.py
new file mode 100644
index 000000000..46e5e15f3
--- /dev/null
+++ b/tests/strands/tools/test_validator.py
@@ -0,0 +1,50 @@
+from strands.tools import _validator
+from strands.types.content import Message
+
+
+def test_validate_and_prepare_tools():
+ message: Message = {
+ "role": "assistant",
+ "content": [
+ {"text": "value"},
+ {"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {"key": "value"}}},
+ {"toolUse": {"toolUseId": "t2-invalid"}},
+ ],
+ }
+
+ tool_uses = []
+ tool_results = []
+ invalid_tool_use_ids = []
+
+ _validator.validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
+
+ tru_tool_uses, tru_tool_results, tru_invalid_tool_use_ids = tool_uses, tool_results, invalid_tool_use_ids
+ exp_tool_uses = [
+ {
+ "input": {
+ "key": "value",
+ },
+ "name": "test_tool",
+ "toolUseId": "t1",
+ },
+ {
+ "name": "INVALID_TOOL_NAME",
+ "toolUseId": "t2-invalid",
+ },
+ ]
+ exp_tool_results = [
+ {
+ "content": [
+ {
+ "text": "Error: tool name missing",
+ },
+ ],
+ "status": "error",
+ "toolUseId": "t2-invalid",
+ },
+ ]
+ exp_invalid_tool_use_ids = ["t2-invalid"]
+
+ assert tru_tool_uses == exp_tool_uses
+ assert tru_tool_results == exp_tool_results
+ assert tru_invalid_tool_use_ids == exp_invalid_tool_use_ids
diff --git a/tests_integ/tools/executors/test_concurrent.py b/tests_integ/tools/executors/test_concurrent.py
new file mode 100644
index 000000000..27dd468e0
--- /dev/null
+++ b/tests_integ/tools/executors/test_concurrent.py
@@ -0,0 +1,61 @@
+import asyncio
+
+import pytest
+
+import strands
+from strands import Agent
+from strands.tools.executors import ConcurrentToolExecutor
+
+
+@pytest.fixture
+def tool_executor():
+ return ConcurrentToolExecutor()
+
+
+@pytest.fixture
+def tool_events():
+ return []
+
+
+@pytest.fixture
+def time_tool(tool_events):
+ @strands.tool(name="time_tool")
+ async def func():
+ tool_events.append({"name": "time_tool", "event": "start"})
+ await asyncio.sleep(2)
+ tool_events.append({"name": "time_tool", "event": "end"})
+ return "12:00"
+
+ return func
+
+
+@pytest.fixture
+def weather_tool(tool_events):
+ @strands.tool(name="weather_tool")
+ async def func():
+ tool_events.append({"name": "weather_tool", "event": "start"})
+ await asyncio.sleep(1)
+ tool_events.append({"name": "weather_tool", "event": "end"})
+
+ return "sunny"
+
+ return func
+
+
+@pytest.fixture
+def agent(tool_executor, time_tool, weather_tool):
+ return Agent(tools=[time_tool, weather_tool], tool_executor=tool_executor)
+
+
+@pytest.mark.asyncio
+async def test_agent_invoke_async_tool_executor(agent, tool_events):
+ await agent.invoke_async("What is the time and weather in New York?")
+
+ tru_events = tool_events
+ exp_events = [
+ {"name": "time_tool", "event": "start"},
+ {"name": "weather_tool", "event": "start"},
+ {"name": "weather_tool", "event": "end"},
+ {"name": "time_tool", "event": "end"},
+ ]
+ assert tru_events == exp_events
diff --git a/tests_integ/tools/executors/test_sequential.py b/tests_integ/tools/executors/test_sequential.py
new file mode 100644
index 000000000..82fc51a59
--- /dev/null
+++ b/tests_integ/tools/executors/test_sequential.py
@@ -0,0 +1,61 @@
+import asyncio
+
+import pytest
+
+import strands
+from strands import Agent
+from strands.tools.executors import SequentialToolExecutor
+
+
+@pytest.fixture
+def tool_executor():
+ return SequentialToolExecutor()
+
+
+@pytest.fixture
+def tool_events():
+ return []
+
+
+@pytest.fixture
+def time_tool(tool_events):
+ @strands.tool(name="time_tool")
+ async def func():
+ tool_events.append({"name": "time_tool", "event": "start"})
+ await asyncio.sleep(2)
+ tool_events.append({"name": "time_tool", "event": "end"})
+ return "12:00"
+
+ return func
+
+
+@pytest.fixture
+def weather_tool(tool_events):
+ @strands.tool(name="weather_tool")
+ async def func():
+ tool_events.append({"name": "weather_tool", "event": "start"})
+ await asyncio.sleep(1)
+ tool_events.append({"name": "weather_tool", "event": "end"})
+
+ return "sunny"
+
+ return func
+
+
+@pytest.fixture
+def agent(tool_executor, time_tool, weather_tool):
+ return Agent(tools=[time_tool, weather_tool], tool_executor=tool_executor)
+
+
+@pytest.mark.asyncio
+async def test_agent_invoke_async_tool_executor(agent, tool_events):
+ await agent.invoke_async("What is the time and weather in New York?")
+
+ tru_events = tool_events
+ exp_events = [
+ {"name": "time_tool", "event": "start"},
+ {"name": "time_tool", "event": "end"},
+ {"name": "weather_tool", "event": "start"},
+ {"name": "weather_tool", "event": "end"},
+ ]
+ assert tru_events == exp_events
From dbe0fea146749f578bfd73dae22182d69df70a7e Mon Sep 17 00:00:00 2001
From: Nick Clegg
Date: Mon, 25 Aug 2025 11:06:43 -0400
Subject: [PATCH 050/104] feat: Add support for agent invoke with no input, or
Message input (#653)
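
A minimal sketch of the four accepted input shapes (assuming a configured
default model provider; the calls mirror the new tests in this patch):

```python
from strands import Agent

agent = Agent()  # sketch only: assumes a default model provider is configured

# 1. Plain string -> wrapped as a single user message
agent("hello!")

# 2. list[ContentBlock] -> one user message containing all blocks
agent([{"text": "part one"}, {"text": "part two"}])

# 3. Messages -> appended to the conversation as-is, roles preserved
agent([
    {"role": "user", "content": [{"text": "hello"}]},
    {"role": "assistant", "content": [{"text": "..."}]},
])

# 4. No input -> continue from the existing conversation history
agent()
```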
---
src/strands/agent/agent.py | 122 ++++++++++++++++++-------
src/strands/telemetry/tracer.py | 17 ++--
tests/strands/agent/test_agent.py | 82 ++++++++++++++---
tests/strands/telemetry/test_tracer.py | 2 +-
4 files changed, 168 insertions(+), 55 deletions(-)
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index adc554bf4..654b8edce 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -361,14 +361,21 @@ def tool_names(self) -> list[str]:
all_tools = self.tool_registry.get_all_tools_config()
return list(all_tools.keys())
- def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
+ def __call__(self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
- This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
- the conversation history, processes it through the model, executes any tool calls, and returns the final result.
+ This method implements the conversational interface with multiple input patterns:
+ - String input: `agent("hello!")`
+ - ContentBlock list: `agent([{"text": "hello"}, {"image": {...}}])`
+ - Message list: `agent([{"role": "user", "content": [{"text": "hello"}]}])`
+ - No input: `agent()` - uses existing conversation history
Args:
- prompt: User input as text or list of ContentBlock objects for multi-modal content.
+ prompt: User input in various formats:
+ - str: Simple text input
+ - list[ContentBlock]: Multi-modal content blocks
+ - list[Message]: Complete messages with roles
+ - None: Use existing conversation history
**kwargs: Additional parameters to pass through the event loop.
Returns:
@@ -387,14 +394,23 @@ def execute() -> AgentResult:
future = executor.submit(execute)
return future.result()
- async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
+ async def invoke_async(
+ self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any
+ ) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
- This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
- the conversation history, processes it through the model, executes any tool calls, and returns the final result.
+ This method implements the conversational interface with multiple input patterns:
+ - String input: Simple text input
+ - ContentBlock list: Multi-modal content blocks
+ - Message list: Complete messages with roles
+ - No input: Use existing conversation history
Args:
- prompt: User input as text or list of ContentBlock objects for multi-modal content.
+ prompt: User input in various formats:
+ - str: Simple text input
+ - list[ContentBlock]: Multi-modal content blocks
+ - list[Message]: Complete messages with roles
+ - None: Use existing conversation history
**kwargs: Additional parameters to pass through the event loop.
Returns:
@@ -411,7 +427,7 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
return cast(AgentResult, event["result"])
- def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
+ def structured_output(self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None) -> T:
"""This method allows you to get structured output from the agent.
If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -423,7 +439,11 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l
Args:
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
that the agent will use when responding.
- prompt: The prompt to use for the agent (will not be added to conversation history).
+ prompt: The prompt to use for the agent in various formats:
+ - str: Simple text input
+ - list[ContentBlock]: Multi-modal content blocks
+ - list[Message]: Complete messages with roles
+ - None: Use existing conversation history
Raises:
ValueError: If no conversation history or prompt is provided.
@@ -437,7 +457,7 @@ def execute() -> T:
return future.result()
async def structured_output_async(
- self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None
+ self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None
) -> T:
"""This method allows you to get structured output from the agent.
@@ -462,12 +482,8 @@ async def structured_output_async(
try:
if not self.messages and not prompt:
raise ValueError("No conversation history or prompt provided")
- # Create temporary messages array if prompt is provided
- if prompt:
- content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
- temp_messages = self.messages + [{"role": "user", "content": content}]
- else:
- temp_messages = self.messages
+
+ temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)
structured_output_span.set_attributes(
{
@@ -499,16 +515,25 @@ async def structured_output_async(
finally:
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
- async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
+ async def stream_async(
+ self,
+ prompt: str | list[ContentBlock] | Messages | None = None,
+ **kwargs: Any,
+ ) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.
- This method provides an asynchronous interface for streaming agent events, allowing
- consumers to process stream events programmatically through an async iterator pattern
- rather than callback functions. This is particularly useful for web servers and other
- async environments.
+ This method provides an asynchronous interface for streaming agent events with multiple input patterns:
+ - String input: Simple text input
+ - ContentBlock list: Multi-modal content blocks
+ - Message list: Complete messages with roles
+ - No input: Use existing conversation history
Args:
- prompt: User input as text or list of ContentBlock objects for multi-modal content.
+ prompt: User input in various formats:
+ - str: Simple text input
+ - list[ContentBlock]: Multi-modal content blocks
+ - list[Message]: Complete messages with roles
+ - None: Use existing conversation history
**kwargs: Additional parameters to pass to the event loop.
Yields:
@@ -532,13 +557,15 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
"""
callback_handler = kwargs.get("callback_handler", self.callback_handler)
- content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
- message: Message = {"role": "user", "content": content}
+ # Process input and get message to add (if any)
+ messages = self._convert_prompt_to_messages(prompt)
+
+ self.trace_span = self._start_agent_trace_span(messages)
- self.trace_span = self._start_agent_trace_span(message)
with trace_api.use_span(self.trace_span):
try:
- events = self._run_loop(message, invocation_state=kwargs)
+ events = self._run_loop(messages, invocation_state=kwargs)
+
async for event in events:
if "callback" in event:
callback_handler(**event["callback"])
@@ -555,12 +582,12 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
raise
async def _run_loop(
- self, message: Message, invocation_state: dict[str, Any]
+ self, messages: Messages, invocation_state: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
"""Execute the agent's event loop with the given message and parameters.
Args:
- message: The user message to add to the conversation.
+ messages: The input messages to add to the conversation.
invocation_state: Additional parameters to pass to the event loop.
Yields:
@@ -571,7 +598,8 @@ async def _run_loop(
try:
yield {"callback": {"init_event_loop": True, **invocation_state}}
- self._append_message(message)
+ for message in messages:
+ self._append_message(message)
# Execute the event loop cycle with retry logic for context limits
events = self._execute_event_loop_cycle(invocation_state)
@@ -629,6 +657,34 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
async for event in events:
yield event
+ def _convert_prompt_to_messages(self, prompt: str | list[ContentBlock] | Messages | None) -> Messages:
+ messages: Messages | None = None
+ if prompt is not None:
+ if isinstance(prompt, str):
+ # String input - convert to user message
+ messages = [{"role": "user", "content": [{"text": prompt}]}]
+ elif isinstance(prompt, list):
+ if len(prompt) == 0:
+ # Empty list
+ messages = []
+ # Check if all items in the input list are dictionaries
+ elif all(isinstance(item, dict) for item in prompt):
+ # Check if all items are messages
+ if all(all(key in item for key in Message.__annotations__.keys()) for item in prompt):
+ # Messages input - add all messages to conversation
+ messages = cast(Messages, prompt)
+
+ # Check if all items are content blocks
+ elif all(any(key in ContentBlock.__annotations__.keys() for key in item) for item in prompt):
+ # Treat as List[ContentBlock] input - convert to user message
+ # This allows invalid structures to be passed through to the model
+ messages = [{"role": "user", "content": cast(list[ContentBlock], prompt)}]
+ else:  # prompt is None (this else pairs with the outer "if prompt is not None")
+ messages = []
+ if messages is None:
+ raise ValueError("Input prompt must be of type: `str | list[ContentBlock] | Messages | None`.")
+ return messages
+
def _record_tool_execution(
self,
tool: ToolUse,
@@ -694,15 +750,15 @@ def _record_tool_execution(
self._append_message(tool_result_msg)
self._append_message(assistant_msg)
- def _start_agent_trace_span(self, message: Message) -> trace_api.Span:
+ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
"""Starts a trace span for the agent.
Args:
- message: The user message.
+ messages: The input messages.
"""
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
return self.tracer.start_agent_span(
- message=message,
+ messages=messages,
agent_name=self.name,
model_id=model_id,
tools=self.tool_names,
diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py
index 802865189..6b429393d 100644
--- a/src/strands/telemetry/tracer.py
+++ b/src/strands/telemetry/tracer.py
@@ -408,7 +408,7 @@ def end_event_loop_cycle_span(
def start_agent_span(
self,
- message: Message,
+ messages: Messages,
agent_name: str,
model_id: Optional[str] = None,
tools: Optional[list] = None,
@@ -418,7 +418,7 @@ def start_agent_span(
"""Start a new span for an agent invocation.
Args:
- message: The user message being sent to the agent.
+ messages: List of messages being sent to the agent.
agent_name: Name of the agent.
model_id: Optional model identifier.
tools: Optional list of tools being used.
@@ -451,13 +451,12 @@ def start_agent_span(
span = self._start_span(
f"invoke_agent {agent_name}", attributes=attributes, span_kind=trace_api.SpanKind.CLIENT
)
- self._add_event(
- span,
- "gen_ai.user.message",
- event_attributes={
- "content": serialize(message["content"]),
- },
- )
+ for message in messages:
+ self._add_event(
+ span,
+ f"gen_ai.{message['role']}.message",
+ {"content": serialize(message["content"])},
+ )
return span
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 279e2a06e..01d8f977e 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -1332,12 +1332,12 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model
# Verify span was created
mock_tracer.start_agent_span.assert_called_once_with(
+ messages=[{"content": [{"text": "test prompt"}], "role": "user"}],
agent_name="Strands Agents",
- custom_trace_attributes=agent.trace_attributes,
- message={"content": [{"text": "test prompt"}], "role": "user"},
model_id=unittest.mock.ANY,
- system_prompt=agent.system_prompt,
tools=agent.tool_names,
+ system_prompt=agent.system_prompt,
+ custom_trace_attributes=agent.trace_attributes,
)
# Verify span was ended with the result
@@ -1366,12 +1366,12 @@ async def test_event_loop(*args, **kwargs):
# Verify span was created
mock_tracer.start_agent_span.assert_called_once_with(
- custom_trace_attributes=agent.trace_attributes,
+ messages=[{"content": [{"text": "test prompt"}], "role": "user"}],
agent_name="Strands Agents",
- message={"content": [{"text": "test prompt"}], "role": "user"},
model_id=unittest.mock.ANY,
- system_prompt=agent.system_prompt,
tools=agent.tool_names,
+ system_prompt=agent.system_prompt,
+ custom_trace_attributes=agent.trace_attributes,
)
expected_response = AgentResult(
@@ -1404,12 +1404,12 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod
# Verify span was created
mock_tracer.start_agent_span.assert_called_once_with(
- custom_trace_attributes=agent.trace_attributes,
+ messages=[{"content": [{"text": "test prompt"}], "role": "user"}],
agent_name="Strands Agents",
- message={"content": [{"text": "test prompt"}], "role": "user"},
model_id=unittest.mock.ANY,
- system_prompt=agent.system_prompt,
tools=agent.tool_names,
+ system_prompt=agent.system_prompt,
+ custom_trace_attributes=agent.trace_attributes,
)
# Verify span was ended with the exception
@@ -1440,12 +1440,12 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr
# Verify span was created
mock_tracer.start_agent_span.assert_called_once_with(
+ messages=[{"content": [{"text": "test prompt"}], "role": "user"}],
agent_name="Strands Agents",
- custom_trace_attributes=agent.trace_attributes,
- message={"content": [{"text": "test prompt"}], "role": "user"},
model_id=unittest.mock.ANY,
- system_prompt=agent.system_prompt,
tools=agent.tool_names,
+ system_prompt=agent.system_prompt,
+ custom_trace_attributes=agent.trace_attributes,
)
# Verify span was ended with the exception
@@ -1773,6 +1773,63 @@ def test_agent_tool_record_direct_tool_call_disabled_with_non_serializable(agent
assert len(agent.messages) == 0
+def test_agent_empty_invoke():
+ model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}])
+ agent = Agent(model=model, messages=[{"role": "user", "content": [{"text": "hello!"}]}])
+ result = agent()
+ assert str(result) == "hello!\n"
+ assert len(agent.messages) == 2
+
+
+def test_agent_empty_list_invoke():
+ model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}])
+ agent = Agent(model=model, messages=[{"role": "user", "content": [{"text": "hello!"}]}])
+ result = agent([])
+ assert str(result) == "hello!\n"
+ assert len(agent.messages) == 2
+
+
+def test_agent_with_assistant_role_message():
+ model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}])
+ agent = Agent(model=model)
+ assistant_message = [{"role": "assistant", "content": [{"text": "hello..."}]}]
+ result = agent(assistant_message)
+ assert str(result) == "world!\n"
+ assert len(agent.messages) == 2
+
+
+def test_agent_with_multiple_messages_on_invoke():
+ model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}])
+ agent = Agent(model=model)
+ input_messages = [
+ {"role": "user", "content": [{"text": "hello"}]},
+ {"role": "assistant", "content": [{"text": "..."}]},
+ ]
+ result = agent(input_messages)
+ assert str(result) == "world!\n"
+ assert len(agent.messages) == 3
+
+
+def test_agent_with_invalid_input():
+ model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}])
+ agent = Agent(model=model)
+ with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[ContentBlock] | Messages | None`."):
+ agent({"invalid": "input"})
+
+
+def test_agent_with_invalid_input_list():
+ model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}])
+ agent = Agent(model=model)
+ with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[ContentBlock] | Messages | None`."):
+ agent([{"invalid": "input"}])
+
+
+def test_agent_with_list_of_message_and_content_block():
+ model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}])
+ agent = Agent(model=model)
+ with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[ContentBlock] | Messages | None`."):
+ agent([{"role": "user", "content": [{"text": "hello"}]}, {"text": "hello"}])
+
def test_agent_tool_call_parameter_filtering_integration(mock_randint):
"""Test that tool calls properly filter parameters in message recording."""
mock_randint.return_value = 42
@@ -1804,3 +1861,4 @@ def test_tool(action: str) -> str:
assert '"action": "test_value"' in tool_call_text
assert '"agent"' not in tool_call_text
assert '"extra_param"' not in tool_call_text
+
diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py
index dcfce1211..586911bef 100644
--- a/tests/strands/telemetry/test_tracer.py
+++ b/tests/strands/telemetry/test_tracer.py
@@ -369,7 +369,7 @@ def test_start_agent_span(mock_tracer):
span = tracer.start_agent_span(
custom_trace_attributes=custom_attrs,
agent_name="WeatherAgent",
- message={"content": content, "role": "user"},
+ messages=[{"content": content, "role": "user"}],
model_id=model_id,
tools=tools,
)
From b156ea68c824fdb968d4d986a835878b0bfc1b93 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 25 Aug 2025 11:06:54 -0400
Subject: [PATCH 051/104] ci: bump actions/checkout from 4 to 5 (#711)
Bumps [actions/checkout](https://github.com/actions/checkout) from 4 to 5.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v5)
---
updated-dependencies:
- dependency-name: actions/checkout
dependency-version: '5'
dependency-type: direct:production
update-type: version-update:semver-major
...
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
.github/workflows/integration-test.yml | 2 +-
.github/workflows/pypi-publish-on-release.yml | 2 +-
.github/workflows/test-lint.yml | 4 ++--
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index c347e3805..d410bb712 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -52,7 +52,7 @@ jobs:
aws-region: us-east-1
mask-aws-account-id: true
- name: Checkout head commit
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
with:
ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo
persist-credentials: false # Don't persist credentials for subsequent actions
diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml
index 8967c5524..e3c5385a7 100644
--- a/.github/workflows/pypi-publish-on-release.yml
+++ b/.github/workflows/pypi-publish-on-release.yml
@@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
with:
persist-credentials: false
diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml
index 35e0f5841..c0ed4faca 100644
--- a/.github/workflows/test-lint.yml
+++ b/.github/workflows/test-lint.yml
@@ -51,7 +51,7 @@ jobs:
LOG_LEVEL: DEBUG
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
with:
ref: ${{ inputs.ref }} # Explicitly define which commit to check out
persist-credentials: false # Don't persist credentials for subsequent actions
@@ -73,7 +73,7 @@ jobs:
contents: read
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
with:
ref: ${{ inputs.ref }}
persist-credentials: false
From 0283169c7a3e424494e6260d163324f75eeeb8f7 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 25 Aug 2025 11:07:08 -0400
Subject: [PATCH 052/104] ci: bump actions/download-artifact from 4 to 5 (#712)
Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 4 to 5.
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](https://github.com/actions/download-artifact/compare/v4...v5)
---
updated-dependencies:
- dependency-name: actions/download-artifact
dependency-version: '5'
dependency-type: direct:production
update-type: version-update:semver-major
...
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
.github/workflows/pypi-publish-on-release.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml
index e3c5385a7..c2420d747 100644
--- a/.github/workflows/pypi-publish-on-release.yml
+++ b/.github/workflows/pypi-publish-on-release.yml
@@ -74,7 +74,7 @@ jobs:
steps:
- name: Download all the dists
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v5
with:
name: python-package-distributions
path: dist/
From e5e308ff794d02eca035a96e148478ed14747ea9 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 25 Aug 2025 11:25:31 -0400
Subject: [PATCH 053/104] ci: update pytest-cov requirement from <5.0.0,>=4.1.0
to >=4.1.0,<7.0.0 (#705)
---
pyproject.toml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 32de94aa6..8a95ba04c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,8 +57,8 @@ dev = [
"mypy>=1.15.0,<2.0.0",
"pre-commit>=3.2.0,<4.4.0",
"pytest>=8.0.0,<9.0.0",
+ "pytest-cov>=6.0.0,<7.0.0",
"pytest-asyncio>=1.0.0,<1.2.0",
- "pytest-cov>=4.1.0,<5.0.0",
"pytest-xdist>=3.0.0,<4.0.0",
"ruff>=0.12.0,<0.13.0",
]
@@ -143,8 +143,8 @@ features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mis
extra-dependencies = [
"moto>=5.1.0,<6.0.0",
"pytest>=8.0.0,<9.0.0",
+ "pytest-cov>=6.0.0,<7.0.0",
"pytest-asyncio>=1.0.0,<1.2.0",
- "pytest-cov>=4.1.0,<5.0.0",
"pytest-xdist>=3.0.0,<4.0.0",
]
extra-args = [
From 918f0945ea9dd0c786bba9af814268d1387a818b Mon Sep 17 00:00:00 2001
From: mehtarac
Date: Mon, 25 Aug 2025 08:34:11 -0700
Subject: [PATCH 054/104] fix: prevent path traversal for message_id in
file_session_manager (#728)
* fix: prevent path traversal for message_id in file_session_manager
* fix: prevent path traversal for message_id in session managers
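
A standalone sketch of the new guard (`validate_message_id` is a hypothetical
name for illustration; in the patch the check lives inline in both
FileSessionManager._get_message_path and S3SessionManager._get_message_path):

```python
# Mirror of the added validation: message ids must be integers, so traversal
# strings can never reach the path-joining code that follows the check.
def validate_message_id(message_id: object) -> None:
    if not isinstance(message_id, int):
        raise ValueError(f"message_id=<{message_id}> | message id must be an integer")

validate_message_id(7)  # ok

try:
    validate_message_id("../../../secret")
except ValueError as err:
    print(err)  # message_id=<../../../secret> | message id must be an integer
```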
---
src/strands/session/file_session_manager.py | 6 +++++
src/strands/session/s3_session_manager.py | 7 +++++-
.../session/test_file_session_manager.py | 22 +++++++++++++++++--
.../session/test_s3_session_manager.py | 20 ++++++++++++++++-
4 files changed, 51 insertions(+), 4 deletions(-)
diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py
index 9df86e17a..14e71d07c 100644
--- a/src/strands/session/file_session_manager.py
+++ b/src/strands/session/file_session_manager.py
@@ -86,7 +86,13 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->
message_id: Index of the message
Returns:
The filename for the message
+
+ Raises:
+ ValueError: If message_id is not an integer.
"""
+ if not isinstance(message_id, int):
+ raise ValueError(f"message_id=<{message_id}> | message id must be an integer")
+
agent_path = self._get_agent_path(session_id, agent_id)
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json")
diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py
index d15e6e3bd..da1735e35 100644
--- a/src/strands/session/s3_session_manager.py
+++ b/src/strands/session/s3_session_manager.py
@@ -113,11 +113,16 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->
session_id: ID of the session
agent_id: ID of the agent
message_id: Index of the message
- **kwargs: Additional keyword arguments for future extensibility.
Returns:
The key for the message
+
+ Raises:
+ ValueError: If message_id is not an integer.
"""
+ if not isinstance(message_id, int):
+ raise ValueError(f"message_id=<{message_id}> | message id must be an integer")
+
agent_path = self._get_agent_path(session_id, agent_id)
return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json"
diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py
index a89222b7e..036591924 100644
--- a/tests/strands/session/test_file_session_manager.py
+++ b/tests/strands/session/test_file_session_manager.py
@@ -224,14 +224,14 @@ def test_read_messages_with_new_agent(file_manager, sample_session, sample_agent
file_manager.create_session(sample_session)
file_manager.create_agent(sample_session.session_id, sample_agent)
- result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message")
+ result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999)
assert result is None
def test_read_nonexistent_message(file_manager, sample_session, sample_agent):
"""Test reading a message that doesnt exist."""
- result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message")
+ result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999)
assert result is None
@@ -390,3 +390,21 @@ def test__get_session_path_invalid_session_id(session_id, file_manager):
def test__get_agent_path_invalid_agent_id(agent_id, file_manager):
with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"):
file_manager._get_agent_path("session1", agent_id)
+
+
+@pytest.mark.parametrize(
+ "message_id",
+ [
+ "../../../secret",
+ "../../attack",
+ "../escape",
+ "path/traversal",
+ "not_an_int",
+ None,
+ [],
+ ],
+)
+def test__get_message_path_invalid_message_id(message_id, file_manager):
+ """Test that message_id that is not an integer raises ValueError."""
+ with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"):
+ file_manager._get_message_path("session1", "agent1", message_id)
diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py
index 71bff3050..50fb303f7 100644
--- a/tests/strands/session/test_s3_session_manager.py
+++ b/tests/strands/session/test_s3_session_manager.py
@@ -251,7 +251,7 @@ def test_read_nonexistent_message(s3_manager, sample_session, sample_agent, samp
s3_manager.create_agent(sample_session.session_id, sample_agent)
# Read message
- result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message")
+ result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999)
assert result is None
@@ -356,3 +356,21 @@ def test__get_session_path_invalid_session_id(session_id, s3_manager):
def test__get_agent_path_invalid_agent_id(agent_id, s3_manager):
with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"):
s3_manager._get_agent_path("session1", agent_id)
+
+
+@pytest.mark.parametrize(
+ "message_id",
+ [
+ "../../../secret",
+ "../../attack",
+ "../escape",
+ "path/traversal",
+ "not_an_int",
+ None,
+ [],
+ ],
+)
+def test__get_message_path_invalid_message_id(message_id, s3_manager):
+ """Test that message_id that is not an integer raises ValueError."""
+ with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"):
+ s3_manager._get_message_path("session1", "agent1", message_id)
From f028dc96df64d97f1f5a05be9ec2fc7cd8467a8d Mon Sep 17 00:00:00 2001
From: Nick Clegg
Date: Mon, 25 Aug 2025 13:31:42 -0400
Subject: [PATCH 055/104] fix: Add AgentInput TypeAlias (#738)
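
A short sketch of what the alias buys; the AgentInput definition is copied
from this patch, while `summarize` is a hypothetical caller:

```python
from typing import TypeAlias

from strands.types.content import ContentBlock, Messages

AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None

def summarize(prompt: AgentInput = None) -> None:
    """One annotation now covers all four accepted input shapes."""
    ...
```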
---
CONTRIBUTING.md | 2 +-
src/strands/agent/agent.py | 32 ++++++++++++-------
src/strands/session/file_session_manager.py | 2 +-
src/strands/session/s3_session_manager.py | 4 +--
tests/strands/agent/test_agent.py | 2 +-
.../session/test_file_session_manager.py | 2 +-
.../session/test_s3_session_manager.py | 2 +-
7 files changed, 28 insertions(+), 18 deletions(-)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index add4825fd..93970ed64 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -49,7 +49,7 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as
Alternatively, install development dependencies in a manually created virtual environment:
```bash
- pip install -e ".[dev]" && pip install -e ".[litellm]"
+ pip install -e ".[all]"
```
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 654b8edce..66099cb1d 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -14,7 +14,19 @@
import logging
import random
from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
+from typing import (
+ Any,
+ AsyncGenerator,
+ AsyncIterator,
+ Callable,
+ Mapping,
+ Optional,
+ Type,
+ TypeAlias,
+ TypeVar,
+ Union,
+ cast,
+)
from opentelemetry import trace as trace_api
from pydantic import BaseModel
@@ -55,6 +67,8 @@
# TypeVar for generic structured output
T = TypeVar("T", bound=BaseModel)
+AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None
+
# Sentinel class and object to distinguish between explicit None and default parameter value
class _DefaultCallbackHandlerSentinel:
@@ -361,7 +375,7 @@ def tool_names(self) -> list[str]:
all_tools = self.tool_registry.get_all_tools_config()
return list(all_tools.keys())
- def __call__(self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any) -> AgentResult:
+ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
This method implements the conversational interface with multiple input patterns:
@@ -394,9 +408,7 @@ def execute() -> AgentResult:
future = executor.submit(execute)
return future.result()
- async def invoke_async(
- self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any
- ) -> AgentResult:
+ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
This method implements the conversational interface with multiple input patterns:
@@ -427,7 +439,7 @@ async def invoke_async(
return cast(AgentResult, event["result"])
- def structured_output(self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None) -> T:
+ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T:
"""This method allows you to get structured output from the agent.
If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -456,9 +468,7 @@ def execute() -> T:
future = executor.submit(execute)
return future.result()
- async def structured_output_async(
- self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None
- ) -> T:
+ async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T:
"""This method allows you to get structured output from the agent.
If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -517,7 +527,7 @@ async def structured_output_async(
async def stream_async(
self,
- prompt: str | list[ContentBlock] | Messages | None = None,
+ prompt: AgentInput = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.
@@ -657,7 +667,7 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
async for event in events:
yield event
- def _convert_prompt_to_messages(self, prompt: str | list[ContentBlock] | Messages | None) -> Messages:
+ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
messages: Messages | None = None
if prompt is not None:
if isinstance(prompt, str):
diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py
index 14e71d07c..491f7ad60 100644
--- a/src/strands/session/file_session_manager.py
+++ b/src/strands/session/file_session_manager.py
@@ -92,7 +92,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->
"""
if not isinstance(message_id, int):
raise ValueError(f"message_id=<{message_id}> | message id must be an integer")
-
+
agent_path = self._get_agent_path(session_id, agent_id)
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json")
diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py
index da1735e35..c6ce28d80 100644
--- a/src/strands/session/s3_session_manager.py
+++ b/src/strands/session/s3_session_manager.py
@@ -116,13 +116,13 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->
Returns:
The key for the message
-
+
Raises:
ValueError: If message_id is not an integer.
"""
if not isinstance(message_id, int):
raise ValueError(f"message_id=<{message_id}> | message id must be an integer")
-
+
agent_path = self._get_agent_path(session_id, agent_id)
return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json"
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 01d8f977e..67ea5940a 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -1830,6 +1830,7 @@ def test_agent_with_list_of_message_and_content_block():
with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[ContentBlock] | Messages | None`."):
agent([{"role": "user", "content": [{"text": "hello"}]}, {"text", "hello"}])
+
def test_agent_tool_call_parameter_filtering_integration(mock_randint):
"""Test that tool calls properly filter parameters in message recording."""
mock_randint.return_value = 42
@@ -1861,4 +1862,3 @@ def test_tool(action: str) -> str:
assert '"action": "test_value"' in tool_call_text
assert '"agent"' not in tool_call_text
assert '"extra_param"' not in tool_call_text
-
diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py
index 036591924..f124ddf58 100644
--- a/tests/strands/session/test_file_session_manager.py
+++ b/tests/strands/session/test_file_session_manager.py
@@ -396,7 +396,7 @@ def test__get_agent_path_invalid_agent_id(agent_id, file_manager):
"message_id",
[
"../../../secret",
- "../../attack",
+ "../../attack",
"../escape",
"path/traversal",
"not_an_int",
diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py
index 50fb303f7..c4d6a0154 100644
--- a/tests/strands/session/test_s3_session_manager.py
+++ b/tests/strands/session/test_s3_session_manager.py
@@ -362,7 +362,7 @@ def test__get_agent_path_invalid_agent_id(agent_id, s3_manager):
"message_id",
[
"../../../secret",
- "../../attack",
+ "../../attack",
"../escape",
"path/traversal",
"not_an_int",
From 0fac6480b5d64bf4500c0ea257e0d237a639cd64 Mon Sep 17 00:00:00 2001
From: Nick Clegg
Date: Tue, 26 Aug 2025 10:39:56 -0400
Subject: [PATCH 056/104] fix: Move AgentInput to types submodule (#746)
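
Sketch of the new import location for downstream code (`handle` is a
hypothetical helper):

```python
from strands.types.agent import AgentInput

def handle(prompt: AgentInput = None) -> None:
    ...  # annotate helpers without importing strands.agent internals
```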
---
src/strands/agent/agent.py | 4 +---
src/strands/types/agent.py | 10 ++++++++++
2 files changed, 11 insertions(+), 3 deletions(-)
create mode 100644 src/strands/types/agent.py
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 66099cb1d..e2aed9d2b 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -22,7 +22,6 @@
Mapping,
Optional,
Type,
- TypeAlias,
TypeVar,
Union,
cast,
@@ -51,6 +50,7 @@
from ..tools.executors._executor import ToolExecutor
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
+from ..types.agent import AgentInput
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.tools import ToolResult, ToolUse
@@ -67,8 +67,6 @@
# TypeVar for generic structured output
T = TypeVar("T", bound=BaseModel)
-AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None
-
# Sentinel class and object to distinguish between explicit None and default parameter value
class _DefaultCallbackHandlerSentinel:
diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py
new file mode 100644
index 000000000..151c88f89
--- /dev/null
+++ b/src/strands/types/agent.py
@@ -0,0 +1,10 @@
+"""Agent-related type definitions for the SDK.
+
+This module defines the types used for an Agent.
+"""
+
+from typing import TypeAlias
+
+from .content import ContentBlock, Messages
+
+AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None
From aa03b3dfffbc98303bba8f57a19e98b1bdb239af Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Wed, 27 Aug 2025 09:14:45 -0400
Subject: [PATCH 057/104] feat: Implement typed events internally (#745)
Step 1/N for implementing typed events; this first step preserves the existing behavior with no changes to the public API.
A follow-up change will update how we invoke callbacks and pass invocation_state around, while this one only adds typed classes for events internally.
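
The new src/strands/types/_events.py is not shown in this excerpt; a plausible
minimal shape, assuming each typed event subclasses dict so that existing
`"callback" in event` checks and `callback_handler(**event["callback"])` call
sites keep working unchanged, is:

```python
from typing import Any

class TypedEvent(dict):
    """Hypothetical base: a dict with a typed constructor."""

class StartEvent(TypedEvent):
    def __init__(self) -> None:
        # Payload copied from the dict this event replaces in event_loop.py.
        super().__init__({"callback": {"start": True}})

class EventLoopThrottleEvent(TypedEvent):
    def __init__(self, delay: int, invocation_state: dict[str, Any]) -> None:
        super().__init__({"callback": {"event_loop_throttled_delay": delay, **invocation_state}})

event = StartEvent()
assert "callback" in event and event["callback"] == {"start": True}
```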
---------
Co-authored-by: Mackenzie Zastrow
---
src/strands/agent/agent.py | 3 +-
src/strands/event_loop/event_loop.py | 35 ++-
src/strands/event_loop/streaming.py | 50 ++--
src/strands/types/_events.py | 238 ++++++++++++++++++
.../strands/agent/hooks/test_agent_events.py | 159 ++++++++++++
tests/strands/agent/test_agent.py | 129 +++++-----
tests/strands/event_loop/test_streaming.py | 9 +
7 files changed, 529 insertions(+), 94 deletions(-)
create mode 100644 src/strands/types/_events.py
create mode 100644 tests/strands/agent/hooks/test_agent_events.py
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index e2aed9d2b..8233c4bfe 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -50,6 +50,7 @@
from ..tools.executors._executor import ToolExecutor
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
+from ..types._events import InitEventLoopEvent
from ..types.agent import AgentInput
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
@@ -604,7 +605,7 @@ async def _run_loop(
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
try:
- yield {"callback": {"init_event_loop": True, **invocation_state}}
+ yield InitEventLoopEvent(invocation_state)
for message in messages:
self._append_message(message)
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index 524ecc3e8..a166902eb 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -25,6 +25,15 @@
from ..telemetry.metrics import Trace
from ..telemetry.tracer import get_tracer
from ..tools._validator import validate_and_prepare_tools
+from ..types._events import (
+ EventLoopStopEvent,
+ EventLoopThrottleEvent,
+ ForceStopEvent,
+ ModelMessageEvent,
+ StartEvent,
+ StartEventLoopEvent,
+ ToolResultMessageEvent,
+)
from ..types.content import Message
from ..types.exceptions import (
ContextWindowOverflowException,
@@ -91,8 +100,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes)
invocation_state["event_loop_cycle_trace"] = cycle_trace
- yield {"callback": {"start": True}}
- yield {"callback": {"start_event_loop": True}}
+ yield StartEvent()
+ yield StartEventLoopEvent()
# Create tracer span for this event loop cycle
tracer = get_tracer()
@@ -175,7 +184,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
if isinstance(e, ModelThrottledException):
if attempt + 1 == MAX_ATTEMPTS:
- yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
+ yield ForceStopEvent(reason=e)
raise e
logger.debug(
@@ -189,7 +198,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
time.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)
- yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}}
+ yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state)
else:
raise e
@@ -201,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
# Add the response message to the conversation
agent.messages.append(message)
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
- yield {"callback": {"message": message}}
+ yield ModelMessageEvent(message=message)
# Update metrics
agent.event_loop_metrics.update_usage(usage)
@@ -235,8 +244,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
cycle_start_time=cycle_start_time,
invocation_state=invocation_state,
)
- async for event in events:
- yield event
+ async for typed_event in events:
+ yield typed_event
return
@@ -264,11 +273,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
tracer.end_span_with_error(cycle_span, str(e), e)
# Handle any other exceptions
- yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
+ yield ForceStopEvent(reason=e)
logger.exception("cycle failed")
raise EventLoopException(e, invocation_state["request_state"]) from e
- yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
+ yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
@@ -295,7 +304,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id)
cycle_trace.add_child(recursive_trace)
- yield {"callback": {"start": True}}
+ yield StartEvent()
events = event_loop_cycle(agent=agent, invocation_state=invocation_state)
async for event in events:
@@ -339,7 +348,7 @@ async def _handle_tool_execution(
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
if not tool_uses:
- yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
+ yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
return
tool_events = agent.tool_executor._execute(
@@ -358,7 +367,7 @@ async def _handle_tool_execution(
agent.messages.append(tool_result_message)
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
- yield {"callback": {"message": tool_result_message}}
+    yield ToolResultMessageEvent(message=tool_result_message)
if cycle_span:
tracer = get_tracer()
@@ -366,7 +375,7 @@ async def _handle_tool_execution(
if invocation_state["request_state"].get("stop_event_loop", False):
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
- yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
+ yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
return
events = recurse_event_loop(agent=agent, invocation_state=invocation_state)
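For context, the throttle handling above retries with exponential backoff before force-stopping. A minimal standalone sketch (the exception class is a stand-in for strands.types.exceptions.ModelThrottledException; MAX_DELAY comes from event_loop.py, the other constants are assumptions inferred from the tests):

```python
import time

MAX_ATTEMPTS = 6   # assumed from the tests: five retries, then a final raise
INITIAL_DELAY = 4  # assumed starting delay in seconds
MAX_DELAY = 240    # 4 minutes, from event_loop.py


class ModelThrottledException(Exception):
    """Stand-in for strands.types.exceptions.ModelThrottledException."""


def call_model_with_backoff(invoke):
    """Retry invoke() on throttling, doubling the delay up to MAX_DELAY."""
    current_delay = INITIAL_DELAY
    for attempt in range(MAX_ATTEMPTS):
        try:
            return invoke()
        except ModelThrottledException:
            if attempt + 1 == MAX_ATTEMPTS:
                raise
            time.sleep(current_delay)
            current_delay = min(current_delay * 2, MAX_DELAY)


print(call_model_with_backoff(lambda: "ok"))  # succeeds on the first attempt
```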
diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py
index 1f8c260a4..7507c6d75 100644
--- a/src/strands/event_loop/streaming.py
+++ b/src/strands/event_loop/streaming.py
@@ -5,6 +5,16 @@
from typing import Any, AsyncGenerator, AsyncIterable, Optional
from ..models.model import Model
+from ..types._events import (
+ ModelStopReason,
+ ModelStreamChunkEvent,
+ ModelStreamEvent,
+ ReasoningSignatureStreamEvent,
+ ReasoningTextStreamEvent,
+ TextStreamEvent,
+ ToolUseStreamEvent,
+ TypedEvent,
+)
from ..types.content import ContentBlock, Message, Messages
from ..types.streaming import (
ContentBlockDeltaEvent,
@@ -105,7 +115,7 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]:
def handle_content_block_delta(
event: ContentBlockDeltaEvent, state: dict[str, Any]
-) -> tuple[dict[str, Any], dict[str, Any]]:
+) -> tuple[dict[str, Any], ModelStreamEvent]:
"""Handles content block delta updates by appending text, tool input, or reasoning content to the state.
Args:
@@ -117,18 +127,18 @@ def handle_content_block_delta(
"""
delta_content = event["delta"]
- callback_event = {}
+ typed_event: ModelStreamEvent = ModelStreamEvent({})
if "toolUse" in delta_content:
if "input" not in state["current_tool_use"]:
state["current_tool_use"]["input"] = ""
state["current_tool_use"]["input"] += delta_content["toolUse"]["input"]
- callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]}
+ typed_event = ToolUseStreamEvent(delta_content, state["current_tool_use"])
elif "text" in delta_content:
state["text"] += delta_content["text"]
- callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}
+ typed_event = TextStreamEvent(text=delta_content["text"], delta=delta_content)
elif "reasoningContent" in delta_content:
if "text" in delta_content["reasoningContent"]:
@@ -136,24 +146,22 @@ def handle_content_block_delta(
state["reasoningText"] = ""
state["reasoningText"] += delta_content["reasoningContent"]["text"]
- callback_event["callback"] = {
- "reasoningText": delta_content["reasoningContent"]["text"],
- "delta": delta_content,
- "reasoning": True,
- }
+ typed_event = ReasoningTextStreamEvent(
+ reasoning_text=delta_content["reasoningContent"]["text"],
+ delta=delta_content,
+ )
elif "signature" in delta_content["reasoningContent"]:
if "signature" not in state:
state["signature"] = ""
state["signature"] += delta_content["reasoningContent"]["signature"]
- callback_event["callback"] = {
- "reasoning_signature": delta_content["reasoningContent"]["signature"],
- "delta": delta_content,
- "reasoning": True,
- }
+ typed_event = ReasoningSignatureStreamEvent(
+ reasoning_signature=delta_content["reasoningContent"]["signature"],
+ delta=delta_content,
+ )
- return state, callback_event
+ return state, typed_event
def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
@@ -251,7 +259,7 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]:
return usage, metrics
-async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[dict[str, Any], None]:
+async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[TypedEvent, None]:
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.
Args:
@@ -274,14 +282,14 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
metrics: Metrics = Metrics(latencyMs=0)
async for chunk in chunks:
- yield {"callback": {"event": chunk}}
+ yield ModelStreamChunkEvent(chunk=chunk)
if "messageStart" in chunk:
state["message"] = handle_message_start(chunk["messageStart"], state["message"])
elif "contentBlockStart" in chunk:
state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"])
elif "contentBlockDelta" in chunk:
- state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state)
- yield callback_event
+ state, typed_event = handle_content_block_delta(chunk["contentBlockDelta"], state)
+ yield typed_event
elif "contentBlockStop" in chunk:
state = handle_content_block_stop(state)
elif "messageStop" in chunk:
@@ -291,7 +299,7 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
elif "redactContent" in chunk:
handle_redact_content(chunk["redactContent"], state)
- yield {"stop": (stop_reason, state["message"], usage, metrics)}
+ yield ModelStopReason(stop_reason=stop_reason, message=state["message"], usage=usage, metrics=metrics)
async def stream_messages(
@@ -299,7 +307,7 @@ async def stream_messages(
system_prompt: Optional[str],
messages: Messages,
tool_specs: list[ToolSpec],
-) -> AsyncGenerator[dict[str, Any], None]:
+) -> AsyncGenerator[TypedEvent, None]:
"""Streams messages to the model and processes the response.
Args:
diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py
new file mode 100644
index 000000000..1bddc5877
--- /dev/null
+++ b/src/strands/types/_events.py
@@ -0,0 +1,238 @@
+"""event system for the Strands Agents framework.
+
+This module defines the event types that are emitted during agent execution,
+providing a structured way to observe the different stages of the event loop and
+agent lifecycle.
+"""
+
+from typing import TYPE_CHECKING, Any
+
+from ..telemetry import EventLoopMetrics
+from .content import Message
+from .event_loop import Metrics, StopReason, Usage
+from .streaming import ContentBlockDelta, StreamEvent
+
+if TYPE_CHECKING:
+ pass
+
+
+class TypedEvent(dict):
+ """Base class for all typed events in the agent system."""
+
+ def __init__(self, data: dict[str, Any] | None = None) -> None:
+ """Initialize the typed event with optional data.
+
+ Args:
+ data: Optional dictionary of event data to initialize with
+ """
+ super().__init__(data or {})
+
+
+class InitEventLoopEvent(TypedEvent):
+ """Event emitted at the very beginning of agent execution.
+
+ This event is fired before any processing begins and provides access to the
+ initial invocation state.
+
+ Args:
+ invocation_state: The invocation state passed into the request
+ """
+
+ def __init__(self, invocation_state: dict) -> None:
+ """Initialize the event loop initialization event."""
+ super().__init__({"callback": {"init_event_loop": True, **invocation_state}})
+
+
+class StartEvent(TypedEvent):
+ """Event emitted at the start of each event loop cycle.
+
+ !!deprecated!!
+ Use StartEventLoopEvent instead.
+
+    This event marks the beginning of a new processing cycle within the agent's
+ event loop. It's fired before model invocation and tool execution begin.
+ """
+
+ def __init__(self) -> None:
+ """Initialize the event loop start event."""
+ super().__init__({"callback": {"start": True}})
+
+
+class StartEventLoopEvent(TypedEvent):
+ """Event emitted when the event loop cycle begins processing.
+
+ This event is fired after StartEvent and indicates that the event loop
+ has begun its core processing logic, including model invocation preparation.
+ """
+
+ def __init__(self) -> None:
+ """Initialize the event loop processing start event."""
+ super().__init__({"callback": {"start_event_loop": True}})
+
+
+class ModelStreamChunkEvent(TypedEvent):
+ """Event emitted during model response streaming for each raw chunk."""
+
+ def __init__(self, chunk: StreamEvent) -> None:
+ """Initialize with streaming delta data from the model.
+
+ Args:
+ chunk: Incremental streaming data from the model response
+ """
+ super().__init__({"callback": {"event": chunk}})
+
+
+class ModelStreamEvent(TypedEvent):
+ """Event emitted during model response streaming.
+
+ This event is fired when the model produces streaming output during response
+ generation.
+ """
+
+ def __init__(self, delta_data: dict[str, Any]) -> None:
+ """Initialize with streaming delta data from the model.
+
+ Args:
+ delta_data: Incremental streaming data from the model response
+ """
+ super().__init__(delta_data)
+
+
+class ToolUseStreamEvent(ModelStreamEvent):
+ """Event emitted during tool use input streaming."""
+
+ def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None:
+ """Initialize with delta and current tool use state."""
+ super().__init__({"callback": {"delta": delta, "current_tool_use": current_tool_use}})
+
+
+class TextStreamEvent(ModelStreamEvent):
+ """Event emitted during text content streaming."""
+
+ def __init__(self, delta: ContentBlockDelta, text: str) -> None:
+ """Initialize with delta and text content."""
+ super().__init__({"callback": {"data": text, "delta": delta}})
+
+
+class ReasoningTextStreamEvent(ModelStreamEvent):
+ """Event emitted during reasoning text streaming."""
+
+ def __init__(self, delta: ContentBlockDelta, reasoning_text: str | None) -> None:
+ """Initialize with delta and reasoning text."""
+ super().__init__({"callback": {"reasoningText": reasoning_text, "delta": delta, "reasoning": True}})
+
+
+class ReasoningSignatureStreamEvent(ModelStreamEvent):
+ """Event emitted during reasoning signature streaming."""
+
+ def __init__(self, delta: ContentBlockDelta, reasoning_signature: str | None) -> None:
+ """Initialize with delta and reasoning signature."""
+ super().__init__({"callback": {"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True}})
+
+
+class ModelStopReason(TypedEvent):
+ """Event emitted during reasoning signature streaming."""
+
+ def __init__(
+ self,
+ stop_reason: StopReason,
+ message: Message,
+ usage: Usage,
+ metrics: Metrics,
+ ) -> None:
+ """Initialize with the final execution results.
+
+ Args:
+            stop_reason: Why the model stopped generating
+ message: Final message from the model
+ usage: Usage information from the model
+ metrics: Execution metrics and performance data
+ """
+ super().__init__({"stop": (stop_reason, message, usage, metrics)})
+
+
+class EventLoopStopEvent(TypedEvent):
+ """Event emitted when the agent execution completes normally."""
+
+ def __init__(
+ self,
+ stop_reason: StopReason,
+ message: Message,
+ metrics: "EventLoopMetrics",
+ request_state: Any,
+ ) -> None:
+ """Initialize with the final execution results.
+
+ Args:
+ stop_reason: Why the agent execution stopped
+ message: Final message from the model
+ metrics: Execution metrics and performance data
+ request_state: Final state of the agent execution
+ """
+ super().__init__({"stop": (stop_reason, message, metrics, request_state)})
+
+
+class EventLoopThrottleEvent(TypedEvent):
+ """Event emitted when the event loop is throttled due to rate limiting."""
+
+ def __init__(self, delay: int, invocation_state: dict[str, Any]) -> None:
+ """Initialize with the throttle delay duration.
+
+ Args:
+ delay: Delay in seconds before the next retry attempt
+ invocation_state: The invocation state passed into the request
+ """
+ super().__init__({"callback": {"event_loop_throttled_delay": delay, **invocation_state}})
+
+
+class ModelMessageEvent(TypedEvent):
+ """Event emitted when the model invocation has completed.
+
+ This event is fired whenever the model generates a response message that
+ gets added to the conversation history.
+ """
+
+ def __init__(self, message: Message) -> None:
+ """Initialize with the model-generated message.
+
+ Args:
+ message: The response message from the model
+ """
+ super().__init__({"callback": {"message": message}})
+
+
+class ToolResultMessageEvent(TypedEvent):
+ """Event emitted when tool results are formatted as a message.
+
+ This event is fired when tool execution results are converted into a
+ message format to be added to the conversation history. It provides
+ access to the formatted message containing tool results.
+ """
+
+ def __init__(self, message: Any) -> None:
+ """Initialize with the model-generated message.
+
+ Args:
+ message: Message containing tool results for conversation history
+ """
+ super().__init__({"callback": {"message": message}})
+
+
+class ForceStopEvent(TypedEvent):
+ """Event emitted when the agent execution is forcibly stopped, either by a tool or by an exception."""
+
+ def __init__(self, reason: str | Exception) -> None:
+ """Initialize with the reason for forced stop.
+
+ Args:
+ reason: String description or exception that caused the forced stop
+ """
+ super().__init__(
+ {
+ "callback": {
+ "force_stop": True,
+ "force_stop_reason": str(reason),
+ # "force_stop_reason_exception": reason if reason and isinstance(reason, Exception) else MISSING,
+ }
+ }
+ )
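As a rough usage sketch (a hypothetical consumer, not code from this patch; it assumes the private module path above), the new classes drain exactly like the old dict events, since each one is a dict. At this point in the series the callback payload still lives under the "callback" key:

```python
import asyncio

from strands.types._events import EventLoopStopEvent, StartEvent


async def fake_loop():
    # One callback-style event followed by a terminal "stop" event.
    yield StartEvent()
    # None stands in for EventLoopMetrics, which this sketch doesn't construct.
    yield EventLoopStopEvent("end_turn", {"role": "assistant", "content": []}, None, {})


async def drain() -> None:
    async for event in fake_loop():
        if "callback" in event:
            print("callback:", event["callback"])  # {"start": True}
        elif "stop" in event:
            stop_reason, message, metrics, request_state = event["stop"]
            print("stopped:", stop_reason)


asyncio.run(drain())
```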
diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py
new file mode 100644
index 000000000..d63dd97d4
--- /dev/null
+++ b/tests/strands/agent/hooks/test_agent_events.py
@@ -0,0 +1,159 @@
+import unittest.mock
+from unittest.mock import ANY, MagicMock, call
+
+import pytest
+
+import strands
+from strands import Agent
+from strands.agent import AgentResult
+from strands.types._events import TypedEvent
+from strands.types.exceptions import ModelThrottledException
+from tests.fixtures.mocked_model_provider import MockedModelProvider
+
+
+@pytest.fixture
+def mock_time():
+ with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock:
+ yield mock
+
+
+@pytest.mark.asyncio
+async def test_stream_async_e2e(alist, mock_time):
+ @strands.tool
+ def fake_tool(agent: Agent):
+ return "Done!"
+
+ mock_provider = MockedModelProvider(
+ [
+ {"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"},
+ {"role": "assistant", "content": [{"text": "Okay invoking tool!"}]},
+ {
+ "role": "assistant",
+ "content": [{"toolUse": {"name": "fake_tool", "toolUseId": "123", "input": {}}}],
+ },
+ {"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
+ ]
+ )
+ model = MagicMock()
+ model.stream.side_effect = [
+ ModelThrottledException("ThrottlingException | ConverseStream"),
+ ModelThrottledException("ThrottlingException | ConverseStream"),
+ mock_provider.stream([]),
+ ]
+
+ mock_callback = unittest.mock.Mock()
+ agent = Agent(model=model, tools=[fake_tool], callback_handler=mock_callback)
+
+ stream = agent.stream_async("Do the stuff", arg1=1013)
+
+ # Base object with common properties
+ throttle_props = {
+ "agent": ANY,
+ "event_loop_cycle_id": ANY,
+ "event_loop_cycle_span": ANY,
+ "event_loop_cycle_trace": ANY,
+ "arg1": 1013,
+ "request_state": {},
+ }
+
+ tru_events = await alist(stream)
+ exp_events = [
+ {"arg1": 1013, "init_event_loop": True},
+ {"start": True},
+ {"start_event_loop": True},
+ {"event_loop_throttled_delay": 8, **throttle_props},
+ {"event_loop_throttled_delay": 16, **throttle_props},
+ {"event": {"messageStart": {"role": "assistant"}}},
+ {"event": {"redactContent": {"redactUserContentMessage": "BLOCKED!"}}},
+ {"event": {"contentBlockStart": {"start": {}}}},
+ {"event": {"contentBlockDelta": {"delta": {"text": "INPUT BLOCKED!"}}}},
+ {
+ "agent": ANY,
+ "arg1": 1013,
+ "data": "INPUT BLOCKED!",
+ "delta": {"text": "INPUT BLOCKED!"},
+ "event_loop_cycle_id": ANY,
+ "event_loop_cycle_span": ANY,
+ "event_loop_cycle_trace": ANY,
+ "request_state": {},
+ },
+ {"event": {"contentBlockStop": {}}},
+ {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}},
+ {"message": {"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}},
+ {
+ "result": AgentResult(
+ stop_reason="guardrail_intervened",
+ message={"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"},
+ metrics=ANY,
+ state={},
+ ),
+ },
+ ]
+ assert tru_events == exp_events
+
+ exp_calls = [call(**event) for event in exp_events]
+ act_calls = mock_callback.call_args_list
+ assert act_calls == exp_calls
+
+ # Ensure that all events coming out of the agent are *not* typed events
+ typed_events = [event for event in tru_events if isinstance(event, TypedEvent)]
+ assert typed_events == []
+
+
+@pytest.mark.asyncio
+async def test_event_loop_cycle_text_response_throttling_early_end(
+ agenerator,
+ alist,
+ mock_time,
+):
+ model = MagicMock()
+ model.stream.side_effect = [
+ ModelThrottledException("ThrottlingException | ConverseStream"),
+ ModelThrottledException("ThrottlingException | ConverseStream"),
+ ModelThrottledException("ThrottlingException | ConverseStream"),
+ ModelThrottledException("ThrottlingException | ConverseStream"),
+ ModelThrottledException("ThrottlingException | ConverseStream"),
+ ModelThrottledException("ThrottlingException | ConverseStream"),
+ ]
+
+ mock_callback = unittest.mock.Mock()
+ with pytest.raises(ModelThrottledException):
+ agent = Agent(model=model, callback_handler=mock_callback)
+
+ # Because we're throwing an exception, we manually collect the items here
+ tru_events = []
+ stream = agent.stream_async("Do the stuff", arg1=1013)
+ async for event in stream:
+ tru_events.append(event)
+
+ # Base object with common properties
+ common_props = {
+ "agent": ANY,
+ "event_loop_cycle_id": ANY,
+ "event_loop_cycle_span": ANY,
+ "event_loop_cycle_trace": ANY,
+ "arg1": 1013,
+ "request_state": {},
+ }
+
+ exp_events = [
+ {"init_event_loop": True, "arg1": 1013},
+ {"start": True},
+ {"start_event_loop": True},
+ {"event_loop_throttled_delay": 8, **common_props},
+ {"event_loop_throttled_delay": 16, **common_props},
+ {"event_loop_throttled_delay": 32, **common_props},
+ {"event_loop_throttled_delay": 64, **common_props},
+ {"event_loop_throttled_delay": 128, **common_props},
+ {"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"},
+ ]
+
+ assert tru_events == exp_events
+
+ exp_calls = [call(**event) for event in exp_events]
+ act_calls = mock_callback.call_args_list
+ assert act_calls == exp_calls
+
+ # Ensure that all events coming out of the agent are *not* typed events
+ typed_events = [event for event in tru_events if isinstance(event, TypedEvent)]
+ assert typed_events == []
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 67ea5940a..a4a8af09a 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -668,62 +668,71 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator):
)
agent("test")
- callback_handler.assert_has_calls(
- [
- unittest.mock.call(init_event_loop=True),
- unittest.mock.call(start=True),
- unittest.mock.call(start_event_loop=True),
- unittest.mock.call(
- event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}
- ),
- unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}),
- unittest.mock.call(
- agent=agent,
- current_tool_use={"toolUseId": "123", "name": "test", "input": {}},
- delta={"toolUse": {"input": '{"value"}'}},
- event_loop_cycle_id=unittest.mock.ANY,
- event_loop_cycle_span=unittest.mock.ANY,
- event_loop_cycle_trace=unittest.mock.ANY,
- request_state={},
- ),
- unittest.mock.call(event={"contentBlockStop": {}}),
- unittest.mock.call(event={"contentBlockStart": {"start": {}}}),
- unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}),
- unittest.mock.call(
- agent=agent,
- delta={"reasoningContent": {"text": "value"}},
- event_loop_cycle_id=unittest.mock.ANY,
- event_loop_cycle_span=unittest.mock.ANY,
- event_loop_cycle_trace=unittest.mock.ANY,
- reasoning=True,
- reasoningText="value",
- request_state={},
- ),
- unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}),
- unittest.mock.call(
- agent=agent,
- delta={"reasoningContent": {"signature": "value"}},
- event_loop_cycle_id=unittest.mock.ANY,
- event_loop_cycle_span=unittest.mock.ANY,
- event_loop_cycle_trace=unittest.mock.ANY,
- reasoning=True,
- reasoning_signature="value",
- request_state={},
- ),
- unittest.mock.call(event={"contentBlockStop": {}}),
- unittest.mock.call(event={"contentBlockStart": {"start": {}}}),
- unittest.mock.call(event={"contentBlockDelta": {"delta": {"text": "value"}}}),
- unittest.mock.call(
- agent=agent,
- data="value",
- delta={"text": "value"},
- event_loop_cycle_id=unittest.mock.ANY,
- event_loop_cycle_span=unittest.mock.ANY,
- event_loop_cycle_trace=unittest.mock.ANY,
- request_state={},
- ),
- unittest.mock.call(event={"contentBlockStop": {}}),
- unittest.mock.call(
+ assert callback_handler.call_args_list == [
+ unittest.mock.call(init_event_loop=True),
+ unittest.mock.call(start=True),
+ unittest.mock.call(start_event_loop=True),
+ unittest.mock.call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}),
+ unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}),
+ unittest.mock.call(
+ agent=agent,
+ current_tool_use={"toolUseId": "123", "name": "test", "input": {}},
+ delta={"toolUse": {"input": '{"value"}'}},
+ event_loop_cycle_id=unittest.mock.ANY,
+ event_loop_cycle_span=unittest.mock.ANY,
+ event_loop_cycle_trace=unittest.mock.ANY,
+ request_state={},
+ ),
+ unittest.mock.call(event={"contentBlockStop": {}}),
+ unittest.mock.call(event={"contentBlockStart": {"start": {}}}),
+ unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}),
+ unittest.mock.call(
+ agent=agent,
+ delta={"reasoningContent": {"text": "value"}},
+ event_loop_cycle_id=unittest.mock.ANY,
+ event_loop_cycle_span=unittest.mock.ANY,
+ event_loop_cycle_trace=unittest.mock.ANY,
+ reasoning=True,
+ reasoningText="value",
+ request_state={},
+ ),
+ unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}),
+ unittest.mock.call(
+ agent=agent,
+ delta={"reasoningContent": {"signature": "value"}},
+ event_loop_cycle_id=unittest.mock.ANY,
+ event_loop_cycle_span=unittest.mock.ANY,
+ event_loop_cycle_trace=unittest.mock.ANY,
+ reasoning=True,
+ reasoning_signature="value",
+ request_state={},
+ ),
+ unittest.mock.call(event={"contentBlockStop": {}}),
+ unittest.mock.call(event={"contentBlockStart": {"start": {}}}),
+ unittest.mock.call(event={"contentBlockDelta": {"delta": {"text": "value"}}}),
+ unittest.mock.call(
+ agent=agent,
+ data="value",
+ delta={"text": "value"},
+ event_loop_cycle_id=unittest.mock.ANY,
+ event_loop_cycle_span=unittest.mock.ANY,
+ event_loop_cycle_trace=unittest.mock.ANY,
+ request_state={},
+ ),
+ unittest.mock.call(event={"contentBlockStop": {}}),
+ unittest.mock.call(
+ message={
+ "role": "assistant",
+ "content": [
+ {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}},
+ {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}},
+ {"text": "value"},
+ ],
+ },
+ ),
+ unittest.mock.call(
+ result=AgentResult(
+ stop_reason="end_turn",
message={
"role": "assistant",
"content": [
@@ -732,9 +741,11 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator):
{"text": "value"},
],
},
- ),
- ],
- )
+ metrics=unittest.mock.ANY,
+ state={},
+ )
+ ),
+ ]
@pytest.mark.asyncio
diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py
index 7760c498a..fd9548dae 100644
--- a/tests/strands/event_loop/test_streaming.py
+++ b/tests/strands/event_loop/test_streaming.py
@@ -4,6 +4,7 @@
import strands
import strands.event_loop
+from strands.types._events import TypedEvent
from strands.types.streaming import (
ContentBlockDeltaEvent,
ContentBlockStartEvent,
@@ -562,6 +563,10 @@ async def test_process_stream(response, exp_events, agenerator, alist):
tru_events = await alist(stream)
assert tru_events == exp_events
+ # Ensure that we're getting typed events coming out of process_stream
+ non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)]
+ assert non_typed_events == []
+
@pytest.mark.asyncio
async def test_stream_messages(agenerator, alist):
@@ -624,3 +629,7 @@ async def test_stream_messages(agenerator, alist):
None,
"test prompt",
)
+
+    # Ensure that we're getting typed events coming out of stream_messages
+ non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)]
+ assert non_typed_events == []
From d9f8d8a76c80eb5296b4a60f778d62192241c128 Mon Sep 17 00:00:00 2001
From: Patrick Gray
Date: Thu, 28 Aug 2025 09:47:17 -0400
Subject: [PATCH 058/104] summarization manager - add summary prompt to
messages (#698)
* summarization manager - add summary prompt to messages
* summarize conversation - assistant to user role
* fix test
* add period
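The core of the change, as a standalone sketch with stand-in data (mirrors the cast in the diff below): the summary produced by the summarization agent is re-tagged with the user role before being spliced back into history, so the trimmed conversation starts on a user message.

```python
from typing import Any, Dict, cast

# Stand-in for result.message returned by the summarization agent:
result_message: Dict[str, Any] = {
    "role": "assistant",
    "content": [{"text": "This is a summary of the conversation."}],
}

# Re-tag the assistant summary as a user message (as in the diff below):
summary = cast(Dict[str, Any], {**result_message, "role": "user"})
assert summary["role"] == "user"
```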
---
.../conversation_manager/summarizing_conversation_manager.py | 5 ++---
tests/strands/agent/test_summarizing_conversation_manager.py | 4 ++--
.../test_summarizing_conversation_manager_integration.py | 4 ++--
3 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py
index 60e832215..b08b6853e 100644
--- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py
+++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py
@@ -1,7 +1,7 @@
"""Summarizing conversation history management with configurable options."""
import logging
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, cast
from typing_extensions import override
@@ -201,8 +201,7 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message:
# Use the agent to generate summary with rich content (can use tools if needed)
result = summarization_agent("Please summarize this conversation.")
-
- return result.message
+ return cast(Message, {**result.message, "role": "user"})
finally:
# Restore original agent state
diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py
index a97104412..6003a1710 100644
--- a/tests/strands/agent/test_summarizing_conversation_manager.py
+++ b/tests/strands/agent/test_summarizing_conversation_manager.py
@@ -99,7 +99,7 @@ def test_reduce_context_with_summarization(summarizing_manager, mock_agent):
assert len(mock_agent.messages) == 4
# First message should be the summary
- assert mock_agent.messages[0]["role"] == "assistant"
+ assert mock_agent.messages[0]["role"] == "user"
first_content = mock_agent.messages[0]["content"][0]
assert "text" in first_content and "This is a summary of the conversation." in first_content["text"]
@@ -438,7 +438,7 @@ def test_reduce_context_tool_pair_adjustment_works_with_forward_search():
assert len(mock_agent.messages) == 2
# First message should be the summary
- assert mock_agent.messages[0]["role"] == "assistant"
+ assert mock_agent.messages[0]["role"] == "user"
summary_content = mock_agent.messages[0]["content"][0]
assert "text" in summary_content and "This is a summary of the conversation." in summary_content["text"]
diff --git a/tests_integ/test_summarizing_conversation_manager_integration.py b/tests_integ/test_summarizing_conversation_manager_integration.py
index 719520b8d..b205c723f 100644
--- a/tests_integ/test_summarizing_conversation_manager_integration.py
+++ b/tests_integ/test_summarizing_conversation_manager_integration.py
@@ -160,7 +160,7 @@ def test_summarization_with_context_overflow(model):
# First message should be the summary (assistant message)
summary_message = agent.messages[0]
- assert summary_message["role"] == "assistant"
+ assert summary_message["role"] == "user"
assert len(summary_message["content"]) > 0
# Verify the summary contains actual text content
@@ -362,7 +362,7 @@ def test_dedicated_summarization_agent(model, summarization_model):
# Get the summary message
summary_message = agent.messages[0]
- assert summary_message["role"] == "assistant"
+ assert summary_message["role"] == "user"
# Extract summary text
summary_text = None
From 6dadbce85bbfef200bf3283810597895aa7ad2dc Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Thu, 28 Aug 2025 10:09:30 -0400
Subject: [PATCH 059/104] feat: Use TypedEvent inheritance for callback
behavior (#755)
Move away from "callback" nested properties in the dict and explicitly passing invocation_state migrating to behaviors on the TypedEvent:
- TypedEvent.is_callback_event for determining if an event should be yielded and or invoked in the callback
- TypedEvent.prepare for taking in invocation_state
Customers still only get dictionaries, as we decided that this will remain an implementation detail for the time being, but this makes the events typed all the way up until *just* before we yield events back to the caller
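A self-contained sketch of the resulting dispatch (simplified from the stream_async and _events.py hunks below): prepare() merges invocation_state into the events that want it, and only events whose is_callback_event is true are surfaced to the caller.

```python
import asyncio


class TypedEvent(dict):
    @property
    def is_callback_event(self) -> bool:
        return True

    def as_dict(self) -> dict:
        return {**self}

    def prepare(self, invocation_state: dict) -> None:
        ...


class EventLoopThrottleEvent(TypedEvent):
    def __init__(self, delay: int) -> None:
        super().__init__({"event_loop_throttled_delay": delay})

    def prepare(self, invocation_state: dict) -> None:
        self.update(invocation_state)  # merged just before emission


class ToolResultEvent(TypedEvent):
    def __init__(self, tool_result: dict) -> None:
        super().__init__({"tool_result": tool_result})

    @property
    def is_callback_event(self) -> bool:
        return False  # internal-only; never surfaced to the caller


async def run_loop():
    yield EventLoopThrottleEvent(delay=8)
    yield ToolResultEvent({"toolUseId": "1", "status": "success"})


async def main() -> None:
    invocation_state = {"arg1": 1013}
    async for event in run_loop():
        event.prepare(invocation_state)
        if event.is_callback_event:
            print(event.as_dict())  # only the throttle event prints


asyncio.run(main())
```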
---------
Co-authored-by: Mackenzie Zastrow
---
src/strands/agent/agent.py | 31 +--
src/strands/event_loop/event_loop.py | 22 +-
src/strands/tools/executors/_executor.py | 23 +-
src/strands/tools/executors/concurrent.py | 7 +-
src/strands/tools/executors/sequential.py | 7 +-
src/strands/types/_events.py | 145 +++++++++++--
tests/strands/agent/test_agent.py | 15 +-
tests/strands/event_loop/test_event_loop.py | 2 +-
tests/strands/event_loop/test_streaming.py | 196 +++++++-----------
.../tools/executors/test_concurrent.py | 16 +-
.../strands/tools/executors/test_executor.py | 28 +--
.../tools/executors/test_sequential.py | 11 +-
12 files changed, 288 insertions(+), 215 deletions(-)
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 8233c4bfe..1e64f5adb 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -50,7 +50,7 @@
from ..tools.executors._executor import ToolExecutor
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
-from ..types._events import InitEventLoopEvent
+from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
from ..types.agent import AgentInput
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
@@ -576,13 +576,16 @@ async def stream_async(
events = self._run_loop(messages, invocation_state=kwargs)
async for event in events:
- if "callback" in event:
- callback_handler(**event["callback"])
- yield event["callback"]
+ event.prepare(invocation_state=kwargs)
+
+ if event.is_callback_event:
+ as_dict = event.as_dict()
+ callback_handler(**as_dict)
+ yield as_dict
result = AgentResult(*event["stop"])
callback_handler(result=result)
- yield {"result": result}
+ yield AgentResultEvent(result=result).as_dict()
self._end_agent_trace_span(response=result)
@@ -590,9 +593,7 @@ async def stream_async(
self._end_agent_trace_span(error=e)
raise
- async def _run_loop(
- self, messages: Messages, invocation_state: dict[str, Any]
- ) -> AsyncGenerator[dict[str, Any], None]:
+ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
"""Execute the agent's event loop with the given message and parameters.
Args:
@@ -605,7 +606,7 @@ async def _run_loop(
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
try:
- yield InitEventLoopEvent(invocation_state)
+ yield InitEventLoopEvent()
for message in messages:
self._append_message(message)
@@ -616,13 +617,13 @@ async def _run_loop(
# Signal from the model provider that the message sent by the user should be redacted,
# likely due to a guardrail.
if (
- event.get("callback")
- and event["callback"].get("event")
- and event["callback"]["event"].get("redactContent")
- and event["callback"]["event"]["redactContent"].get("redactUserContentMessage")
+ isinstance(event, ModelStreamChunkEvent)
+ and event.chunk
+ and event.chunk.get("redactContent")
+ and event.chunk["redactContent"].get("redactUserContentMessage")
):
self.messages[-1]["content"] = [
- {"text": event["callback"]["event"]["redactContent"]["redactUserContentMessage"]}
+ {"text": str(event.chunk["redactContent"]["redactUserContentMessage"])}
]
if self._session_manager:
self._session_manager.redact_latest_message(self.messages[-1], self)
@@ -632,7 +633,7 @@ async def _run_loop(
self.conversation_manager.apply_management(self)
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
- async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
+ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
"""Execute the event loop cycle with retry logic for context window limits.
This internal method handles the execution of the event loop cycle and implements
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index a166902eb..a99ecc8a6 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -30,9 +30,11 @@
EventLoopThrottleEvent,
ForceStopEvent,
ModelMessageEvent,
+ ModelStopReason,
StartEvent,
StartEventLoopEvent,
ToolResultMessageEvent,
+ TypedEvent,
)
from ..types.content import Message
from ..types.exceptions import (
@@ -56,7 +58,7 @@
MAX_DELAY = 240 # 4 minutes
-async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
+async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
"""Execute a single cycle of the event loop.
This core function processes a single conversation turn, handling model inference, tool execution, and error
@@ -139,17 +141,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
)
try:
- # TODO: To maintain backwards compatibility, we need to combine the stream event with invocation_state
- # before yielding to the callback handler. This will be revisited when migrating to strongly
- # typed events.
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
- if "callback" in event:
- yield {
- "callback": {
- **event["callback"],
- **(invocation_state if "delta" in event["callback"] else {}),
- }
- }
+ if not isinstance(event, ModelStopReason):
+ yield event
stop_reason, message, usage, metrics = event["stop"]
invocation_state.setdefault("request_state", {})
@@ -198,7 +192,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
time.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)
- yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state)
+ yield EventLoopThrottleEvent(delay=current_delay)
else:
raise e
@@ -280,7 +274,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
-async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
+async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
"""Make a recursive call to event_loop_cycle with the current state.
This function is used when the event loop needs to continue processing after tool execution.
@@ -321,7 +315,7 @@ async def _handle_tool_execution(
cycle_span: Any,
cycle_start_time: float,
invocation_state: dict[str, Any],
-) -> AsyncGenerator[dict[str, Any], None]:
+) -> AsyncGenerator[TypedEvent, None]:
"""Handles the execution of tools requested by the model during an event loop cycle.
Args:
diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py
index 9999b77fc..701a3bac0 100644
--- a/src/strands/tools/executors/_executor.py
+++ b/src/strands/tools/executors/_executor.py
@@ -7,15 +7,16 @@
import abc
import logging
import time
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
from opentelemetry import trace as trace_api
from ...experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
from ...telemetry.metrics import Trace
from ...telemetry.tracer import get_tracer
+from ...types._events import ToolResultEvent, ToolStreamEvent, TypedEvent
from ...types.content import Message
-from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
+from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse
if TYPE_CHECKING: # pragma: no cover
from ...agent import Agent
@@ -33,7 +34,7 @@ async def _stream(
tool_results: list[ToolResult],
invocation_state: dict[str, Any],
**kwargs: Any,
- ) -> ToolGenerator:
+ ) -> AsyncGenerator[TypedEvent, None]:
"""Stream tool events.
This method adds additional logic to the stream invocation including:
@@ -113,12 +114,12 @@ async def _stream(
result=result,
)
)
- yield after_event.result
+ yield ToolResultEvent(after_event.result)
tool_results.append(after_event.result)
return
async for event in selected_tool.stream(tool_use, invocation_state, **kwargs):
- yield event
+ yield ToolStreamEvent(tool_use, event)
result = cast(ToolResult, event)
@@ -131,7 +132,8 @@ async def _stream(
result=result,
)
)
- yield after_event.result
+
+ yield ToolResultEvent(after_event.result)
tool_results.append(after_event.result)
except Exception as e:
@@ -151,7 +153,7 @@ async def _stream(
exception=e,
)
)
- yield after_event.result
+ yield ToolResultEvent(after_event.result)
tool_results.append(after_event.result)
@staticmethod
@@ -163,7 +165,7 @@ async def _stream_with_trace(
cycle_span: Any,
invocation_state: dict[str, Any],
**kwargs: Any,
- ) -> ToolGenerator:
+ ) -> AsyncGenerator[TypedEvent, None]:
"""Execute tool with tracing and metrics collection.
Args:
@@ -190,7 +192,8 @@ async def _stream_with_trace(
async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs):
yield event
- result = cast(ToolResult, event)
+ result_event = cast(ToolResultEvent, event)
+ result = result_event.tool_result
tool_success = result.get("status") == "success"
tool_duration = time.time() - tool_start_time
@@ -210,7 +213,7 @@ def _execute(
cycle_trace: Trace,
cycle_span: Any,
invocation_state: dict[str, Any],
- ) -> ToolGenerator:
+ ) -> AsyncGenerator[TypedEvent, None]:
"""Execute the given tools according to this executor's strategy.
Args:
diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py
index 7d5dd7fe7..767071bae 100644
--- a/src/strands/tools/executors/concurrent.py
+++ b/src/strands/tools/executors/concurrent.py
@@ -1,12 +1,13 @@
"""Concurrent tool executor implementation."""
import asyncio
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, AsyncGenerator
from typing_extensions import override
from ...telemetry.metrics import Trace
-from ...types.tools import ToolGenerator, ToolResult, ToolUse
+from ...types._events import TypedEvent
+from ...types.tools import ToolResult, ToolUse
from ._executor import ToolExecutor
if TYPE_CHECKING: # pragma: no cover
@@ -25,7 +26,7 @@ async def _execute(
cycle_trace: Trace,
cycle_span: Any,
invocation_state: dict[str, Any],
- ) -> ToolGenerator:
+ ) -> AsyncGenerator[TypedEvent, None]:
"""Execute tools concurrently.
Args:
diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py
index 55b26f6d3..60e5c7fa7 100644
--- a/src/strands/tools/executors/sequential.py
+++ b/src/strands/tools/executors/sequential.py
@@ -1,11 +1,12 @@
"""Sequential tool executor implementation."""
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, AsyncGenerator
from typing_extensions import override
from ...telemetry.metrics import Trace
-from ...types.tools import ToolGenerator, ToolResult, ToolUse
+from ...types._events import TypedEvent
+from ...types.tools import ToolResult, ToolUse
from ._executor import ToolExecutor
if TYPE_CHECKING: # pragma: no cover
@@ -24,7 +25,7 @@ async def _execute(
cycle_trace: Trace,
cycle_span: Any,
invocation_state: dict[str, Any],
- ) -> ToolGenerator:
+ ) -> AsyncGenerator[TypedEvent, None]:
"""Execute tools sequentially.
Args:
diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py
index 1bddc5877..cc2330a81 100644
--- a/src/strands/types/_events.py
+++ b/src/strands/types/_events.py
@@ -5,15 +5,18 @@
agent lifecycle.
"""
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
+
+from typing_extensions import override
from ..telemetry import EventLoopMetrics
from .content import Message
from .event_loop import Metrics, StopReason, Usage
from .streaming import ContentBlockDelta, StreamEvent
+from .tools import ToolResult, ToolUse
if TYPE_CHECKING:
- pass
+ from ..agent import AgentResult
class TypedEvent(dict):
@@ -27,6 +30,23 @@ def __init__(self, data: dict[str, Any] | None = None) -> None:
"""
super().__init__(data or {})
+ @property
+ def is_callback_event(self) -> bool:
+ """True if this event should trigger the callback_handler to fire."""
+ return True
+
+ def as_dict(self) -> dict:
+ """Convert this event to a raw dictionary for emitting purposes."""
+ return {**self}
+
+ def prepare(self, invocation_state: dict) -> None:
+ """Prepare the event for emission by adding invocation state.
+
+        This allows a subset of events to merge the invocation_state into themselves without
+        needing to pass invocation_state around throughout the system.
+ """
+ ...
+
class InitEventLoopEvent(TypedEvent):
"""Event emitted at the very beginning of agent execution.
@@ -38,9 +58,13 @@ class InitEventLoopEvent(TypedEvent):
invocation_state: The invocation state passed into the request
"""
- def __init__(self, invocation_state: dict) -> None:
+ def __init__(self) -> None:
"""Initialize the event loop initialization event."""
- super().__init__({"callback": {"init_event_loop": True, **invocation_state}})
+ super().__init__({"init_event_loop": True})
+
+ @override
+ def prepare(self, invocation_state: dict) -> None:
+ self.update(invocation_state)
class StartEvent(TypedEvent):
@@ -55,7 +79,7 @@ class StartEvent(TypedEvent):
def __init__(self) -> None:
"""Initialize the event loop start event."""
- super().__init__({"callback": {"start": True}})
+ super().__init__({"start": True})
class StartEventLoopEvent(TypedEvent):
@@ -67,7 +91,7 @@ class StartEventLoopEvent(TypedEvent):
def __init__(self) -> None:
"""Initialize the event loop processing start event."""
- super().__init__({"callback": {"start_event_loop": True}})
+ super().__init__({"start_event_loop": True})
class ModelStreamChunkEvent(TypedEvent):
@@ -79,7 +103,11 @@ def __init__(self, chunk: StreamEvent) -> None:
Args:
chunk: Incremental streaming data from the model response
"""
- super().__init__({"callback": {"event": chunk}})
+ super().__init__({"event": chunk})
+
+ @property
+ def chunk(self) -> StreamEvent:
+ return cast(StreamEvent, self.get("event"))
class ModelStreamEvent(TypedEvent):
@@ -97,13 +125,23 @@ def __init__(self, delta_data: dict[str, Any]) -> None:
"""
super().__init__(delta_data)
+ @property
+ def is_callback_event(self) -> bool:
+ # Only invoke a callback if we're non-empty
+ return len(self.keys()) > 0
+
+ @override
+ def prepare(self, invocation_state: dict) -> None:
+ if "delta" in self:
+ self.update(invocation_state)
+
class ToolUseStreamEvent(ModelStreamEvent):
"""Event emitted during tool use input streaming."""
def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None:
"""Initialize with delta and current tool use state."""
- super().__init__({"callback": {"delta": delta, "current_tool_use": current_tool_use}})
+ super().__init__({"delta": delta, "current_tool_use": current_tool_use})
class TextStreamEvent(ModelStreamEvent):
@@ -111,7 +149,7 @@ class TextStreamEvent(ModelStreamEvent):
def __init__(self, delta: ContentBlockDelta, text: str) -> None:
"""Initialize with delta and text content."""
- super().__init__({"callback": {"data": text, "delta": delta}})
+ super().__init__({"data": text, "delta": delta})
class ReasoningTextStreamEvent(ModelStreamEvent):
@@ -119,7 +157,7 @@ class ReasoningTextStreamEvent(ModelStreamEvent):
def __init__(self, delta: ContentBlockDelta, reasoning_text: str | None) -> None:
"""Initialize with delta and reasoning text."""
- super().__init__({"callback": {"reasoningText": reasoning_text, "delta": delta, "reasoning": True}})
+ super().__init__({"reasoningText": reasoning_text, "delta": delta, "reasoning": True})
class ReasoningSignatureStreamEvent(ModelStreamEvent):
@@ -127,7 +165,7 @@ class ReasoningSignatureStreamEvent(ModelStreamEvent):
def __init__(self, delta: ContentBlockDelta, reasoning_signature: str | None) -> None:
"""Initialize with delta and reasoning signature."""
- super().__init__({"callback": {"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True}})
+ super().__init__({"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True})
class ModelStopReason(TypedEvent):
@@ -150,6 +188,11 @@ def __init__(
"""
super().__init__({"stop": (stop_reason, message, usage, metrics)})
+ @property
+ @override
+ def is_callback_event(self) -> bool:
+ return False
+
class EventLoopStopEvent(TypedEvent):
"""Event emitted when the agent execution completes normally."""
@@ -171,18 +214,76 @@ def __init__(
"""
super().__init__({"stop": (stop_reason, message, metrics, request_state)})
+ @property
+ @override
+ def is_callback_event(self) -> bool:
+ return False
+
class EventLoopThrottleEvent(TypedEvent):
"""Event emitted when the event loop is throttled due to rate limiting."""
- def __init__(self, delay: int, invocation_state: dict[str, Any]) -> None:
+ def __init__(self, delay: int) -> None:
"""Initialize with the throttle delay duration.
Args:
delay: Delay in seconds before the next retry attempt
- invocation_state: The invocation state passed into the request
"""
- super().__init__({"callback": {"event_loop_throttled_delay": delay, **invocation_state}})
+ super().__init__({"event_loop_throttled_delay": delay})
+
+ @override
+ def prepare(self, invocation_state: dict) -> None:
+ self.update(invocation_state)
+
+
+class ToolResultEvent(TypedEvent):
+ """Event emitted when a tool execution completes."""
+
+ def __init__(self, tool_result: ToolResult) -> None:
+ """Initialize with the completed tool result.
+
+ Args:
+ tool_result: Final result from the tool execution
+ """
+ super().__init__({"tool_result": tool_result})
+
+ @property
+ def tool_use_id(self) -> str:
+ """The toolUseId associated with this result."""
+ return cast(str, cast(ToolResult, self.get("tool_result")).get("toolUseId"))
+
+ @property
+ def tool_result(self) -> ToolResult:
+ """Final result from the completed tool execution."""
+ return cast(ToolResult, self.get("tool_result"))
+
+ @property
+ @override
+ def is_callback_event(self) -> bool:
+ return False
+
+
+class ToolStreamEvent(TypedEvent):
+ """Event emitted when a tool yields sub-events as part of tool execution."""
+
+ def __init__(self, tool_use: ToolUse, tool_sub_event: Any) -> None:
+ """Initialize with tool streaming data.
+
+ Args:
+ tool_use: The tool invocation producing the stream
+ tool_sub_event: The yielded event from the tool execution
+ """
+ super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_event": tool_sub_event})
+
+ @property
+ def tool_use_id(self) -> str:
+ """The toolUseId associated with this stream."""
+ return cast(str, cast(ToolUse, self.get("tool_stream_tool_use")).get("toolUseId"))
+
+ @property
+ @override
+ def is_callback_event(self) -> bool:
+ return False
class ModelMessageEvent(TypedEvent):
@@ -198,7 +299,7 @@ def __init__(self, message: Message) -> None:
Args:
message: The response message from the model
"""
- super().__init__({"callback": {"message": message}})
+ super().__init__({"message": message})
class ToolResultMessageEvent(TypedEvent):
@@ -215,7 +316,7 @@ def __init__(self, message: Any) -> None:
Args:
message: Message containing tool results for conversation history
"""
- super().__init__({"callback": {"message": message}})
+ super().__init__({"message": message})
class ForceStopEvent(TypedEvent):
@@ -229,10 +330,12 @@ def __init__(self, reason: str | Exception) -> None:
"""
super().__init__(
{
- "callback": {
- "force_stop": True,
- "force_stop_reason": str(reason),
- # "force_stop_reason_exception": reason if reason and isinstance(reason, Exception) else MISSING,
- }
+ "force_stop": True,
+ "force_stop_reason": str(reason),
}
)
+
+
+class AgentResultEvent(TypedEvent):
+ def __init__(self, result: "AgentResult"):
+ super().__init__({"result": result})
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index a4a8af09a..a8561abe4 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -19,6 +19,7 @@
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
from strands.session.repository_session_manager import RepositorySessionManager
from strands.telemetry.tracer import serialize
+from strands.types._events import EventLoopStopEvent, ModelStreamEvent
from strands.types.content import Messages
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException
from strands.types.session import Session, SessionAgent, SessionMessage, SessionType
@@ -406,7 +407,7 @@ async def check_invocation_state(**kwargs):
assert invocation_state["agent"] == agent
# Return expected values from event_loop_cycle
- yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})}
+ yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})
mock_event_loop_cycle.side_effect = check_invocation_state
@@ -1144,12 +1145,12 @@ async def test_stream_async_returns_all_events(mock_event_loop_cycle, alist):
# Define the side effect to simulate callback handler being called multiple times
async def test_event_loop(*args, **kwargs):
- yield {"callback": {"data": "First chunk"}}
- yield {"callback": {"data": "Second chunk"}}
- yield {"callback": {"data": "Final chunk", "complete": True}}
+ yield ModelStreamEvent({"data": "First chunk"})
+ yield ModelStreamEvent({"data": "Second chunk"})
+ yield ModelStreamEvent({"data": "Final chunk", "complete": True})
# Return expected values from event_loop_cycle
- yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})}
+ yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})
mock_event_loop_cycle.side_effect = test_event_loop
mock_callback = unittest.mock.Mock()
@@ -1234,7 +1235,7 @@ async def check_invocation_state(**kwargs):
invocation_state = kwargs["invocation_state"]
assert invocation_state["some_value"] == "a_value"
# Return expected values from event_loop_cycle
- yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})}
+ yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})
mock_event_loop_cycle.side_effect = check_invocation_state
@@ -1366,7 +1367,7 @@ async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_trac
mock_get_tracer.return_value = mock_tracer
async def test_event_loop(*args, **kwargs):
- yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {})}
+ yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {})
mock_event_loop_cycle.side_effect = test_event_loop
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py
index c76514ac8..68f9cc5ab 100644
--- a/tests/strands/event_loop/test_event_loop.py
+++ b/tests/strands/event_loop/test_event_loop.py
@@ -486,7 +486,7 @@ async def test_cycle_exception(
]
tru_stop_event = None
- exp_stop_event = {"callback": {"force_stop": True, "force_stop_reason": "Invalid error presented"}}
+ exp_stop_event = {"force_stop": True, "force_stop_reason": "Invalid error presented"}
with pytest.raises(EventLoopException):
stream = strands.event_loop.event_loop.event_loop_cycle(
diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py
index fd9548dae..fdd560b22 100644
--- a/tests/strands/event_loop/test_streaming.py
+++ b/tests/strands/event_loop/test_streaming.py
@@ -146,7 +146,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use)
],
)
def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args):
- exp_callback_event = {"callback": {**callback_args, "delta": event["delta"]}} if callback_args else {}
+ exp_callback_event = {**callback_args, "delta": event["delta"]} if callback_args else {}
tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state)
@@ -316,85 +316,71 @@ def test_extract_usage_metrics_with_cache_tokens():
],
[
{
- "callback": {
- "event": {
- "messageStart": {
- "role": "assistant",
- },
+ "event": {
+ "messageStart": {
+ "role": "assistant",
},
},
},
{
- "callback": {
- "event": {
- "contentBlockStart": {
- "start": {
- "toolUse": {
- "name": "test",
- "toolUseId": "123",
- },
+ "event": {
+ "contentBlockStart": {
+ "start": {
+ "toolUse": {
+ "name": "test",
+ "toolUseId": "123",
},
},
},
},
},
{
- "callback": {
- "event": {
- "contentBlockDelta": {
- "delta": {
- "toolUse": {
- "input": '{"key": "value"}',
- },
+ "event": {
+ "contentBlockDelta": {
+ "delta": {
+ "toolUse": {
+ "input": '{"key": "value"}',
},
},
},
},
},
{
- "callback": {
- "current_tool_use": {
- "input": {
- "key": "value",
- },
- "name": "test",
- "toolUseId": "123",
+ "current_tool_use": {
+ "input": {
+ "key": "value",
},
- "delta": {
- "toolUse": {
- "input": '{"key": "value"}',
- },
+ "name": "test",
+ "toolUseId": "123",
+ },
+ "delta": {
+ "toolUse": {
+ "input": '{"key": "value"}',
},
},
},
{
- "callback": {
- "event": {
- "contentBlockStop": {},
- },
+ "event": {
+ "contentBlockStop": {},
},
},
{
- "callback": {
- "event": {
- "messageStop": {
- "stopReason": "tool_use",
- },
+ "event": {
+ "messageStop": {
+ "stopReason": "tool_use",
},
},
},
{
- "callback": {
- "event": {
- "metadata": {
- "metrics": {
- "latencyMs": 1,
- },
- "usage": {
- "inputTokens": 1,
- "outputTokens": 1,
- "totalTokens": 1,
- },
+ "event": {
+ "metadata": {
+ "metrics": {
+ "latencyMs": 1,
+ },
+ "usage": {
+ "inputTokens": 1,
+ "outputTokens": 1,
+ "totalTokens": 1,
},
},
},
@@ -417,9 +403,7 @@ def test_extract_usage_metrics_with_cache_tokens():
[{}],
[
{
- "callback": {
- "event": {},
- },
+ "event": {},
},
{
"stop": (
@@ -463,80 +447,64 @@ def test_extract_usage_metrics_with_cache_tokens():
],
[
{
- "callback": {
- "event": {
- "messageStart": {
- "role": "assistant",
- },
+ "event": {
+ "messageStart": {
+ "role": "assistant",
},
},
},
{
- "callback": {
- "event": {
- "contentBlockStart": {
- "start": {},
- },
+ "event": {
+ "contentBlockStart": {
+ "start": {},
},
},
},
{
- "callback": {
- "event": {
- "contentBlockDelta": {
- "delta": {
- "text": "Hello!",
- },
+ "event": {
+ "contentBlockDelta": {
+ "delta": {
+ "text": "Hello!",
},
},
},
},
{
- "callback": {
- "data": "Hello!",
- "delta": {
- "text": "Hello!",
- },
+ "data": "Hello!",
+ "delta": {
+ "text": "Hello!",
},
},
{
- "callback": {
- "event": {
- "contentBlockStop": {},
- },
+ "event": {
+ "contentBlockStop": {},
},
},
{
- "callback": {
- "event": {
- "messageStop": {
- "stopReason": "guardrail_intervened",
- },
+ "event": {
+ "messageStop": {
+ "stopReason": "guardrail_intervened",
},
},
},
{
- "callback": {
- "event": {
- "redactContent": {
- "redactAssistantContentMessage": "REDACTED.",
- "redactUserContentMessage": "REDACTED",
- },
+ "event": {
+ "redactContent": {
+ "redactAssistantContentMessage": "REDACTED.",
+ "redactUserContentMessage": "REDACTED",
},
},
},
{
- "callback": {
- "event": {
- "metadata": {
- "metrics": {
- "latencyMs": 1,
- },
- "usage": {
- "inputTokens": 1,
- "outputTokens": 1,
- "totalTokens": 1,
- },
+ "event": {
+ "metadata": {
+ "metrics": {
+ "latencyMs": 1,
+ },
+ "usage": {
+ "inputTokens": 1,
+ "outputTokens": 1,
+ "totalTokens": 1,
},
},
},
@@ -588,29 +556,23 @@ async def test_stream_messages(agenerator, alist):
tru_events = await alist(stream)
exp_events = [
{
- "callback": {
- "event": {
- "contentBlockDelta": {
- "delta": {
- "text": "test",
- },
+ "event": {
+ "contentBlockDelta": {
+ "delta": {
+ "text": "test",
},
},
},
},
{
- "callback": {
- "data": "test",
- "delta": {
- "text": "test",
- },
+ "data": "test",
+ "delta": {
+ "text": "test",
},
},
{
- "callback": {
- "event": {
- "contentBlockStop": {},
- },
+ "event": {
+ "contentBlockStop": {},
},
},
{
diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py
index 7e0d6c2df..140537add 100644
--- a/tests/strands/tools/executors/test_concurrent.py
+++ b/tests/strands/tools/executors/test_concurrent.py
@@ -1,6 +1,8 @@
import pytest
from strands.tools.executors import ConcurrentToolExecutor
+from strands.types._events import ToolResultEvent, ToolStreamEvent
+from strands.types.tools import ToolUse
@pytest.fixture
@@ -12,21 +14,21 @@ def executor():
async def test_concurrent_executor_execute(
executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist
):
- tool_uses = [
+ tool_uses: list[ToolUse] = [
{"name": "weather_tool", "toolUseId": "1", "input": {}},
{"name": "temperature_tool", "toolUseId": "2", "input": {}},
]
stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state)
- tru_events = sorted(await alist(stream), key=lambda event: event.get("toolUseId"))
+ tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id)
exp_events = [
- {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
- {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
- {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]},
- {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]},
+ ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
+ ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
+ ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
+ ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
]
assert tru_events == exp_events
tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId"))
- exp_results = [exp_events[1], exp_events[3]]
+ exp_results = [exp_events[1].tool_result, exp_events[3].tool_result]
assert tru_results == exp_results
diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py
index edbad3939..56caa950a 100644
--- a/tests/strands/tools/executors/test_executor.py
+++ b/tests/strands/tools/executors/test_executor.py
@@ -6,6 +6,8 @@
from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
from strands.telemetry.metrics import Trace
from strands.tools.executors._executor import ToolExecutor
+from strands.types._events import ToolResultEvent, ToolStreamEvent
+from strands.types.tools import ToolUse
@pytest.fixture
@@ -32,18 +34,18 @@ def tracer():
async def test_executor_stream_yields_result(
executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist
):
- tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}}
+ tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
tru_events = await alist(stream)
exp_events = [
- {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
- {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
+ ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
+ ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
]
assert tru_events == exp_events
tru_results = tool_results
- exp_results = [exp_events[-1]]
+ exp_results = [exp_events[-1].tool_result]
assert tru_results == exp_results
tru_hook_events = hook_events
@@ -73,11 +75,11 @@ async def test_executor_stream_yields_tool_error(
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
tru_events = await alist(stream)
- exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Error: Tool error"}]}]
+ exp_events = [ToolResultEvent({"toolUseId": "1", "status": "error", "content": [{"text": "Error: Tool error"}]})]
assert tru_events == exp_events
tru_results = tool_results
- exp_results = [exp_events[-1]]
+ exp_results = [exp_events[-1].tool_result]
assert tru_results == exp_results
tru_hook_after_event = hook_events[-1]
@@ -98,11 +100,13 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
tru_events = await alist(stream)
- exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Unknown tool: unknown_tool"}]}]
+ exp_events = [
+ ToolResultEvent({"toolUseId": "1", "status": "error", "content": [{"text": "Unknown tool: unknown_tool"}]})
+ ]
assert tru_events == exp_events
tru_results = tool_results
- exp_results = [exp_events[-1]]
+ exp_results = [exp_events[-1].tool_result]
assert tru_results == exp_results
tru_hook_after_event = hook_events[-1]
@@ -120,18 +124,18 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results
async def test_executor_stream_with_trace(
executor, tracer, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist
):
- tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}}
+ tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
stream = executor._stream_with_trace(agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state)
tru_events = await alist(stream)
exp_events = [
- {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
- {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
+ ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
+ ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
]
assert tru_events == exp_events
tru_results = tool_results
- exp_results = [exp_events[-1]]
+ exp_results = [exp_events[-1].tool_result]
assert tru_results == exp_results
tracer.start_tool_call_span.assert_called_once_with(tool_use, cycle_span)
diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py
index d9b32c129..d4e98223e 100644
--- a/tests/strands/tools/executors/test_sequential.py
+++ b/tests/strands/tools/executors/test_sequential.py
@@ -1,6 +1,7 @@
import pytest
from strands.tools.executors import SequentialToolExecutor
+from strands.types._events import ToolResultEvent, ToolStreamEvent
@pytest.fixture
@@ -20,13 +21,13 @@ async def test_sequential_executor_execute(
tru_events = await alist(stream)
exp_events = [
- {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
- {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]},
- {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]},
- {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]},
+ ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
+ ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
+ ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
+ ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
]
assert tru_events == exp_events
tru_results = tool_results
- exp_results = [exp_events[1], exp_events[2]]
+ exp_results = [exp_events[1].tool_result, exp_events[3].tool_result]
assert tru_results == exp_results
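
For illustration, the value semantics these rewritten assertions rely on can be exercised directly. A minimal sketch, assuming the typed events compare by value as the tests above imply (the constructors and accessors used are only those that appear in the diffs):

    from strands.types._events import ToolResultEvent, ToolStreamEvent
    from strands.types.tools import ToolUse

    tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
    result = {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}

    # The executor stream mixes both event types; the tests sort the stream
    # by tool_use_id and compare whole event objects by value.
    stream_event = ToolStreamEvent(tool_use, result)
    result_event = ToolResultEvent(result)

    assert stream_event.tool_use_id == result_event.tool_use_id == "1"
    assert result_event.tool_result == result  # unwraps back to the raw ToolResult dict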
From 47faba0911f00cecaff5cee8145530818a65c5e7 Mon Sep 17 00:00:00 2001
From: Laith Al-Saadoon <9553966+theagenticguy@users.noreply.github.com>
Date: Thu, 28 Aug 2025 08:29:20 -0700
Subject: [PATCH 060/104] feat: claude citation support with BedrockModel
(#631)
* feat: add citations to document content
* feat: adds citation types
* chore: remove uv.lock
* test: add letter.pdf for test-integ
* feat: working bedrock citations feature
* feat: fail early for citations with incompatible models
* fix: validates model ids with cross region inference ids
* Apply suggestion from @Unshure
Co-authored-by: Nick Clegg
* fix: addresses comments
* removes client exception handling
* moves citation into text elif
* puts relative imports back
* fix: tests failing
* Update src/strands/models/bedrock.py
Removes old comment
Co-authored-by: Nick Clegg
* Update src/strands/models/bedrock.py
Removes old comment
Co-authored-by: Nick Clegg
* Update imports in bedrock.py
Refactor imports in bedrock.py to include CitationsDelta.
* feat: typed citation events
---------
Co-authored-by: Nick Clegg
---
src/strands/agent/agent_result.py | 1 -
src/strands/event_loop/streaming.py | 16 +++
src/strands/models/bedrock.py | 29 +++-
src/strands/types/_events.py | 9 ++
src/strands/types/citations.py | 152 +++++++++++++++++++++
src/strands/types/content.py | 3 +
src/strands/types/media.py | 8 +-
src/strands/types/streaming.py | 37 +++++
tests/strands/event_loop/test_streaming.py | 29 ++++
tests_integ/conftest.py | 7 +
tests_integ/letter.pdf | Bin 0 -> 100738 bytes
tests_integ/models/test_model_bedrock.py | 49 ++++++-
tests_integ/test_max_tokens_reached.py | 2 +-
13 files changed, 332 insertions(+), 10 deletions(-)
create mode 100644 src/strands/types/citations.py
create mode 100644 tests_integ/letter.pdf
diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py
index e28e1c5b8..f3758c8d2 100644
--- a/src/strands/agent/agent_result.py
+++ b/src/strands/agent/agent_result.py
@@ -42,5 +42,4 @@ def __str__(self) -> str:
for item in content_array:
if isinstance(item, dict) and "text" in item:
result += item.get("text", "") + "\n"
-
return result
diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py
index 7507c6d75..efe094e5f 100644
--- a/src/strands/event_loop/streaming.py
+++ b/src/strands/event_loop/streaming.py
@@ -6,6 +6,7 @@
from ..models.model import Model
from ..types._events import (
+ CitationStreamEvent,
ModelStopReason,
ModelStreamChunkEvent,
ModelStreamEvent,
@@ -15,6 +16,7 @@
ToolUseStreamEvent,
TypedEvent,
)
+from ..types.citations import CitationsContentBlock
from ..types.content import ContentBlock, Message, Messages
from ..types.streaming import (
ContentBlockDeltaEvent,
@@ -140,6 +142,13 @@ def handle_content_block_delta(
state["text"] += delta_content["text"]
typed_event = TextStreamEvent(text=delta_content["text"], delta=delta_content)
+ elif "citation" in delta_content:
+ if "citationsContent" not in state:
+ state["citationsContent"] = []
+
+ state["citationsContent"].append(delta_content["citation"])
+ typed_event = CitationStreamEvent(delta=delta_content, citation=delta_content["citation"])
+
elif "reasoningContent" in delta_content:
if "text" in delta_content["reasoningContent"]:
if "reasoningText" not in state:
@@ -178,6 +187,7 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
current_tool_use = state["current_tool_use"]
text = state["text"]
reasoning_text = state["reasoningText"]
+ citations_content = state["citationsContent"]
if current_tool_use:
if "input" not in current_tool_use:
@@ -202,6 +212,10 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
elif text:
content.append({"text": text})
state["text"] = ""
+ if citations_content:
+ citations_block: CitationsContentBlock = {"citations": citations_content}
+ content.append({"citationsContent": citations_block})
+ state["citationsContent"] = []
elif reasoning_text:
content_block: ContentBlock = {
@@ -275,6 +289,8 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[T
"text": "",
"current_tool_use": {},
"reasoningText": "",
+ "signature": "",
+ "citationsContent": [],
}
state["content"] = state["message"]["content"]
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index ace35640a..0fe332a47 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -7,7 +7,7 @@
import json
import logging
import os
-from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast
import boto3
from botocore.config import Config as BotocoreConfig
@@ -18,8 +18,11 @@
from ..event_loop import streaming
from ..tools import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Message, Messages
-from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
-from ..types.streaming import StreamEvent
+from ..types.exceptions import (
+ ContextWindowOverflowException,
+ ModelThrottledException,
+)
+from ..types.streaming import CitationsDelta, StreamEvent
from ..types.tools import ToolResult, ToolSpec
from .model import Model
@@ -510,7 +513,7 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
yield {"messageStart": {"role": response["output"]["message"]["role"]}}
# Process content blocks
- for content in response["output"]["message"]["content"]:
+ for content in cast(list[ContentBlock], response["output"]["message"]["content"]):
# Yield contentBlockStart event if needed
if "toolUse" in content:
yield {
@@ -553,6 +556,24 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
}
}
}
+ elif "citationsContent" in content:
+ # For non-streaming citations, emit text and metadata deltas in sequence
+ # to match streaming behavior where they flow naturally
+ if "content" in content["citationsContent"]:
+                    text_content = "".join([generated["text"] for generated in content["citationsContent"]["content"]])
+ yield {
+ "contentBlockDelta": {"delta": {"text": text_content}},
+ }
+
+ for citation in content["citationsContent"]["citations"]:
+                    # Then emit each citation's metadata as its own delta
+
+ citation_metadata: CitationsDelta = {
+ "title": citation["title"],
+ "location": citation["location"],
+ "sourceContent": citation["sourceContent"],
+ }
+ yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}}
# Yield contentBlockStop event
yield {"contentBlockStop": {}}
diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py
index cc2330a81..1a7f48d4b 100644
--- a/src/strands/types/_events.py
+++ b/src/strands/types/_events.py
@@ -10,6 +10,7 @@
from typing_extensions import override
from ..telemetry import EventLoopMetrics
+from .citations import Citation
from .content import Message
from .event_loop import Metrics, StopReason, Usage
from .streaming import ContentBlockDelta, StreamEvent
@@ -152,6 +153,14 @@ def __init__(self, delta: ContentBlockDelta, text: str) -> None:
super().__init__({"data": text, "delta": delta})
+class CitationStreamEvent(ModelStreamEvent):
+ """Event emitted during citation streaming."""
+
+ def __init__(self, delta: ContentBlockDelta, citation: Citation) -> None:
+ """Initialize with delta and citation content."""
+ super().__init__({"callback": {"citation": citation, "delta": delta}})
+
+
class ReasoningTextStreamEvent(ModelStreamEvent):
"""Event emitted during reasoning text streaming."""
diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py
new file mode 100644
index 000000000..b0e28f655
--- /dev/null
+++ b/src/strands/types/citations.py
@@ -0,0 +1,152 @@
+"""Citation type definitions for the SDK.
+
+These types are modeled after the Bedrock API.
+"""
+
+from typing import List, Union
+
+from typing_extensions import TypedDict
+
+
+class CitationsConfig(TypedDict):
+ """Configuration for enabling citations on documents.
+
+ Attributes:
+ enabled: Whether citations are enabled for this document.
+ """
+
+ enabled: bool
+
+
+class DocumentCharLocation(TypedDict, total=False):
+ """Specifies a character-level location within a document.
+
+ Provides precise positioning information for cited content using
+ start and end character indices.
+
+ Attributes:
+ documentIndex: The index of the document within the array of documents
+ provided in the request. Minimum value of 0.
+ start: The starting character position of the cited content within
+ the document. Minimum value of 0.
+ end: The ending character position of the cited content within
+ the document. Minimum value of 0.
+ """
+
+ documentIndex: int
+ start: int
+ end: int
+
+
+class DocumentChunkLocation(TypedDict, total=False):
+ """Specifies a chunk-level location within a document.
+
+ Provides positioning information for cited content using logical
+ document segments or chunks.
+
+ Attributes:
+ documentIndex: The index of the document within the array of documents
+ provided in the request. Minimum value of 0.
+ start: The starting chunk identifier or index of the cited content
+ within the document. Minimum value of 0.
+ end: The ending chunk identifier or index of the cited content
+ within the document. Minimum value of 0.
+ """
+
+ documentIndex: int
+ start: int
+ end: int
+
+
+class DocumentPageLocation(TypedDict, total=False):
+ """Specifies a page-level location within a document.
+
+ Provides positioning information for cited content using page numbers.
+
+ Attributes:
+ documentIndex: The index of the document within the array of documents
+ provided in the request. Minimum value of 0.
+ start: The starting page number of the cited content within
+ the document. Minimum value of 0.
+ end: The ending page number of the cited content within
+ the document. Minimum value of 0.
+ """
+
+ documentIndex: int
+ start: int
+ end: int
+
+
+# Union type for citation locations
+CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]
+
+
+class CitationSourceContent(TypedDict, total=False):
+ """Contains the actual text content from a source document.
+
+ Contains the actual text content from a source document that is being
+ cited or referenced in the model's response.
+
+ Note:
+ This is a UNION type, so only one of the members can be specified.
+
+ Attributes:
+ text: The text content from the source document that is being cited.
+ """
+
+ text: str
+
+
+class CitationGeneratedContent(TypedDict, total=False):
+ """Contains the generated text content that corresponds to a citation.
+
+ Contains the generated text content that corresponds to or is supported
+ by a citation from a source document.
+
+ Note:
+ This is a UNION type, so only one of the members can be specified.
+
+ Attributes:
+ text: The text content that was generated by the model and is
+ supported by the associated citation.
+ """
+
+ text: str
+
+
+class Citation(TypedDict, total=False):
+ """Contains information about a citation that references a source document.
+
+ Citations provide traceability between the model's generated response
+ and the source documents that informed that response.
+
+ Attributes:
+ location: The precise location within the source document where the
+ cited content can be found, including character positions, page
+ numbers, or chunk identifiers.
+ sourceContent: The specific content from the source document that was
+ referenced or cited in the generated response.
+ title: The title or identifier of the source document being cited.
+ """
+
+ location: CitationLocation
+ sourceContent: List[CitationSourceContent]
+ title: str
+
+
+class CitationsContentBlock(TypedDict, total=False):
+ """A content block containing generated text and associated citations.
+
+ This block type is returned when document citations are enabled, providing
+ traceability between the generated content and the source documents that
+ informed the response.
+
+ Attributes:
+ citations: An array of citations that reference the source documents
+ used to generate the associated content.
+ content: The generated content that is supported by the associated
+ citations.
+ """
+
+ citations: List[Citation]
+ content: List[CitationGeneratedContent]
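
Because these are total=False TypedDicts, citation data can be written as plain dict literals that still type-check. For example:

    from strands.types.citations import Citation, CitationsContentBlock, DocumentPageLocation

    location: DocumentPageLocation = {"documentIndex": 0, "start": 2, "end": 3}
    citation: Citation = {
        "title": "letter.pdf",
        "location": location,
        "sourceContent": [{"text": "the cited passage"}],
    }
    block: CitationsContentBlock = {
        "citations": [citation],
        "content": [{"text": "generated text supported by the citation"}],
    }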
diff --git a/src/strands/types/content.py b/src/strands/types/content.py
index 790e9094c..c3eddca4d 100644
--- a/src/strands/types/content.py
+++ b/src/strands/types/content.py
@@ -10,6 +10,7 @@
from typing_extensions import TypedDict
+from .citations import CitationsContentBlock
from .media import DocumentContent, ImageContent, VideoContent
from .tools import ToolResult, ToolUse
@@ -83,6 +84,7 @@ class ContentBlock(TypedDict, total=False):
toolResult: The result for a tool request that a model makes.
toolUse: Information about a tool use request from a model.
video: Video to include in the message.
+ citationsContent: Contains the citations for a document.
"""
cachePoint: CachePoint
@@ -94,6 +96,7 @@ class ContentBlock(TypedDict, total=False):
toolResult: ToolResult
toolUse: ToolUse
video: VideoContent
+ citationsContent: CitationsContentBlock
class SystemContentBlock(TypedDict, total=False):
diff --git a/src/strands/types/media.py b/src/strands/types/media.py
index 29b89e5c6..69cd60cf3 100644
--- a/src/strands/types/media.py
+++ b/src/strands/types/media.py
@@ -5,10 +5,12 @@
- Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html
"""
-from typing import Literal
+from typing import Literal, Optional
from typing_extensions import TypedDict
+from .citations import CitationsConfig
+
DocumentFormat = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
"""Supported document formats."""
@@ -23,7 +25,7 @@ class DocumentSource(TypedDict):
bytes: bytes
-class DocumentContent(TypedDict):
+class DocumentContent(TypedDict, total=False):
"""A document to include in a message.
Attributes:
@@ -35,6 +37,8 @@ class DocumentContent(TypedDict):
format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
name: str
source: DocumentSource
+ citations: Optional[CitationsConfig]
+ context: Optional[str]
ImageFormat = Literal["png", "jpeg", "gif", "webp"]
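
With DocumentContent now total=False, a document can opt in to citations per request. A sketch (the bytes are a stand-in, not a real PDF):

    from strands.types.media import DocumentContent

    document: DocumentContent = {
        "format": "pdf",
        "name": "letter",
        "source": {"bytes": b"%PDF-1.4 placeholder"},  # stand-in payload
        "citations": {"enabled": True},
    }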
diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py
index 9c99b2108..dcfd541a8 100644
--- a/src/strands/types/streaming.py
+++ b/src/strands/types/streaming.py
@@ -9,6 +9,7 @@
from typing_extensions import TypedDict
+from .citations import CitationLocation
from .content import ContentBlockStart, Role
from .event_loop import Metrics, StopReason, Usage
from .guardrails import Trace
@@ -57,6 +58,41 @@ class ContentBlockDeltaToolUse(TypedDict):
input: str
+class CitationSourceContentDelta(TypedDict, total=False):
+ """Contains incremental updates to source content text during streaming.
+
+ Allows clients to build up the cited content progressively during
+ streaming responses.
+
+ Attributes:
+ text: An incremental update to the text content from the source
+ document that is being cited.
+ """
+
+ text: str
+
+
+class CitationsDelta(TypedDict, total=False):
+ """Contains incremental updates to citation information during streaming.
+
+ This allows clients to build up citation data progressively as the
+ response is generated.
+
+ Attributes:
+ location: Specifies the precise location within a source document
+ where cited content can be found. This can include character-level
+ positions, page numbers, or document chunks depending on the
+ document type and indexing method.
+ sourceContent: The specific content from the source document that was
+ referenced or cited in the generated response.
+ title: The title or identifier of the source document being cited.
+ """
+
+ location: CitationLocation
+ sourceContent: list[CitationSourceContentDelta]
+ title: str
+
+
class ReasoningContentBlockDelta(TypedDict, total=False):
"""Delta for reasoning content block in a streaming response.
@@ -83,6 +119,7 @@ class ContentBlockDelta(TypedDict, total=False):
reasoningContent: ReasoningContentBlockDelta
text: str
toolUse: ContentBlockDeltaToolUse
+ citation: CitationsDelta
class ContentBlockDeltaEvent(TypedDict, total=False):
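
The new delta variant slots into ContentBlockDelta like any other member. A typed sketch:

    from strands.types.streaming import CitationsDelta, ContentBlockDelta

    citation_delta: CitationsDelta = {
        "title": "letter.pdf",
        "location": {"documentIndex": 0, "start": 0, "end": 1},
        "sourceContent": [{"text": "cited text"}],
    }
    delta: ContentBlockDelta = {"citation": citation_delta}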
diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py
index fdd560b22..ce12b4e98 100644
--- a/tests/strands/event_loop/test_streaming.py
+++ b/tests/strands/event_loop/test_streaming.py
@@ -164,12 +164,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up
"current_tool_use": {"toolUseId": "123", "name": "test", "input": '{"key": "value"}'},
"text": "",
"reasoningText": "",
+ "citationsContent": [],
},
{
"content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {"key": "value"}}}],
"current_tool_use": {},
"text": "",
"reasoningText": "",
+ "citationsContent": [],
},
),
# Tool Use - Missing input
@@ -179,12 +181,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up
"current_tool_use": {"toolUseId": "123", "name": "test"},
"text": "",
"reasoningText": "",
+ "citationsContent": [],
},
{
"content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}],
"current_tool_use": {},
"text": "",
"reasoningText": "",
+ "citationsContent": [],
},
),
# Text
@@ -194,12 +198,31 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up
"current_tool_use": {},
"text": "test",
"reasoningText": "",
+ "citationsContent": [],
},
{
"content": [{"text": "test"}],
"current_tool_use": {},
"text": "",
"reasoningText": "",
+ "citationsContent": [],
+ },
+ ),
+ # Citations
+ (
+ {
+ "content": [],
+ "current_tool_use": {},
+ "text": "",
+ "reasoningText": "",
+ "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}],
+ },
+ {
+ "content": [],
+ "current_tool_use": {},
+ "text": "",
+ "reasoningText": "",
+ "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}],
},
),
# Reasoning
@@ -210,6 +233,7 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up
"text": "",
"reasoningText": "test",
"signature": "123",
+ "citationsContent": [],
},
{
"content": [{"reasoningContent": {"reasoningText": {"text": "test", "signature": "123"}}}],
@@ -217,6 +241,7 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up
"text": "",
"reasoningText": "",
"signature": "123",
+ "citationsContent": [],
},
),
# Reasoning without signature
@@ -226,12 +251,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up
"current_tool_use": {},
"text": "",
"reasoningText": "test",
+ "citationsContent": [],
},
{
"content": [{"reasoningContent": {"reasoningText": {"text": "test"}}}],
"current_tool_use": {},
"text": "",
"reasoningText": "",
+ "citationsContent": [],
},
),
# Empty
@@ -241,12 +268,14 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_up
"current_tool_use": {},
"text": "",
"reasoningText": "",
+ "citationsContent": [],
},
{
"content": [],
"current_tool_use": {},
"text": "",
"reasoningText": "",
+ "citationsContent": [],
},
),
],
diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py
index 61c2bf9a1..26453e1f7 100644
--- a/tests_integ/conftest.py
+++ b/tests_integ/conftest.py
@@ -22,6 +22,13 @@ def yellow_img(pytestconfig):
return fp.read()
+@pytest.fixture
+def letter_pdf(pytestconfig):
+ path = pytestconfig.rootdir / "tests_integ/letter.pdf"
+ with open(path, "rb") as fp:
+ return fp.read()
+
+
## Async
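
For instance, a hypothetical test consuming the new fixture:

    def test_letter_pdf_fixture(letter_pdf):
        # pytest injects the raw bytes read by the fixture above
        assert letter_pdf[:5] == b"%PDF-"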
diff --git a/tests_integ/letter.pdf b/tests_integ/letter.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..d8c59f749219814f9e76c69393de15719d7f37ae
GIT binary patch
literal 100738
[base85-encoded binary payload omitted: 100738 bytes of PDF data]