Binary file added .DS_Store
Binary file not shown.
27 changes: 11 additions & 16 deletions pyproject.toml
@@ -35,7 +35,6 @@ dependencies = [
"watchdog>=6.0.0,<7.0.0",
"opentelemetry-api>=1.30.0,<2.0.0",
"opentelemetry-sdk>=1.30.0,<2.0.0",
"opentelemetry-instrumentation-threading>=0.51b0,<1.00b0",
]

[project.urls]
@@ -83,12 +82,8 @@ openai = [
otel = [
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
]
writer = [
"writer-sdk>=2.2.0,<3.0.0"
]

a2a = [
"a2a-sdk[sql]>=0.2.11",
"a2a-sdk>=0.2.6",
"uvicorn>=0.34.2",
"httpx>=0.28.1",
"fastapi>=0.115.12",
@@ -100,7 +95,7 @@ a2a = [
source = "vcs"

[tool.hatch.envs.hatch-static-analysis]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"]
dependencies = [
"mypy>=1.15.0,<2.0.0",
"ruff>=0.11.6,<0.12.0",
@@ -124,7 +119,7 @@ lint-fix = [
]

[tool.hatch.envs.hatch-test]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"]
extra-dependencies = [
"moto>=5.1.0,<6.0.0",
"pytest>=8.0.0,<9.0.0",
@@ -140,18 +135,18 @@ extra-args = [

[tool.hatch.envs.dev]
dev-mode = true
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer"]
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel","mistral"]

[tool.hatch.envs.a2a]
dev-mode = true
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"]

[tool.hatch.envs.a2a.scripts]
run = [
"pytest{env:HATCH_TEST_ARGS:} tests/strands/multiagent/a2a {args}"
"pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a {args}"
]
run-cov = [
"pytest{env:HATCH_TEST_ARGS:} tests/strands/multiagent/a2a --cov --cov-config=pyproject.toml {args}"
"pytest{env:HATCH_TEST_ARGS:} tests/multiagent/a2a --cov --cov-config=pyproject.toml {args}"
]
lint-check = [
"ruff check",
@@ -164,11 +159,11 @@ python = ["3.13", "3.12", "3.11", "3.10"]
[tool.hatch.envs.hatch-test.scripts]
run = [
# excluding due to A2A and OTEL http exporter dependency conflict
"pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/strands/multiagent/a2a"
"pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/multiagent/a2a"
]
run-cov = [
# excluding due to A2A and OTEL http exporter dependency conflict
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/strands/multiagent/a2a"
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/multiagent/a2a"
]

cov-combine = []
@@ -195,7 +190,7 @@ test = [
"hatch test --cover --cov-report html --cov-report xml {args}"
]
test-integ = [
"hatch test tests_integ {args}"
"hatch test tests-integ {args}"
]
prepare = [
"hatch fmt --linter",
@@ -230,7 +225,7 @@ ignore_missing_imports = true

[tool.ruff]
line-length = 120
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"]
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests-integ/**/*.py"]

[tool.ruff.lint]
select = [
@@ -290,4 +285,4 @@ style = [
["instruction", ""],
["text", ""],
["disabled", "fg:#858585 italic"]
]
]
Binary file added src/.DS_Store
Binary file not shown.
96 changes: 59 additions & 37 deletions src/strands/agent/agent.py
@@ -12,6 +12,7 @@
import asyncio
import json
import logging
import os
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
@@ -30,7 +31,7 @@
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.models import Model
from ..types.tools import ToolResult, ToolUse
from ..types.tools import ToolConfig, ToolResult, ToolUse
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .conversation_manager import (
@@ -127,18 +128,14 @@ def caller(
"input": kwargs.copy(),
}

async def acall() -> ToolResult:
async for event in run_tool(self._agent, tool_use, kwargs):
_ = event
# Execute the tool
events = run_tool(agent=self._agent, tool=tool_use, kwargs=kwargs)

return cast(ToolResult, event)

def tcall() -> ToolResult:
return asyncio.run(acall())

with ThreadPoolExecutor() as executor:
future = executor.submit(tcall)
tool_result = future.result()
try:
while True:
next(events)
except StopIteration as stop:
tool_result = cast(ToolResult, stop.value)
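
The rewritten direct-call path drives the `run_tool` generator to exhaustion with `next()` and reads the final `ToolResult` from `StopIteration.value`. A self-contained sketch of that generator-draining pattern, with a hypothetical generator standing in for `run_tool`:

```python
from typing import Generator

def run_steps() -> Generator[dict, None, str]:
    """Hypothetical stand-in for run_tool: yields events, returns a result."""
    yield {"event": "start"}
    yield {"event": "progress"}
    return "final-result"  # surfaced to the caller as StopIteration.value

events = run_steps()
try:
    while True:
        next(events)  # consume and discard intermediate events
except StopIteration as stop:
    result = stop.value  # the generator's return value

assert result == "final-result"
```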

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
@@ -189,6 +186,7 @@ def __init__(
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
] = _DEFAULT_CALLBACK_HANDLER,
conversation_manager: Optional[ConversationManager] = None,
max_parallel_tools: int = os.cpu_count() or 1,
record_direct_tool_call: bool = True,
load_tools_from_directory: bool = True,
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
@@ -221,6 +219,8 @@
If explicitly set to None, null_callback_handler is used.
conversation_manager: Manager for conversation history and context window.
Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None.
max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls.
Defaults to os.cpu_count() or 1.
record_direct_tool_call: Whether to record direct tool calls in message history.
Defaults to True.
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
@@ -232,6 +232,9 @@
Defaults to None.
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
Defaults to an empty AgentState object.

Raises:
ValueError: If max_parallel_tools is less than 1.
"""
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
self.messages = messages if messages is not None else []
@@ -260,6 +263,14 @@
):
self.trace_attributes[k] = v

# If max_parallel_tools is 1, we execute tools sequentially
self.thread_pool = None
self.thread_pool_wrapper = None
if max_parallel_tools > 1:
self.thread_pool = ThreadPoolExecutor(max_workers=max_parallel_tools)
elif max_parallel_tools < 1:
raise ValueError("max_parallel_tools must be greater than 0")

self.record_direct_tool_call = record_direct_tool_call
self.load_tools_from_directory = load_tools_from_directory
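
Tool parallelism is now configured at construction time. A hedged usage sketch, assuming the package's top-level `Agent` export and default model settings:

```python
from strands import Agent  # top-level export assumed

# Run up to 4 tools concurrently when the model requests several tool calls;
# the ThreadPoolExecutor is created here and shut down later in __del__.
parallel_agent = Agent(max_parallel_tools=4)

# max_parallel_tools=1 means sequential execution: no thread pool is created.
sequential_agent = Agent(max_parallel_tools=1)

# Values below 1 raise ValueError("max_parallel_tools must be greater than 0").
```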

@@ -324,14 +335,32 @@ def tool_names(self) -> list[str]:
all_tools = self.tool_registry.get_all_tools_config()
return list(all_tools.keys())

def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
@property
def tool_config(self) -> ToolConfig:
"""Get the tool configuration for this agent.

Returns:
The complete tool configuration.
"""
return self.tool_registry.initialize_tool_config()
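
A hedged sketch of inspecting the aggregated configuration; the `tools` parameter and `@tool` decorator are assumed from the package's public API, and the exact `ToolConfig` shape is not shown in this diff:

```python
from strands import Agent, tool  # imports assumed from the public API

@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b

agent = Agent(tools=[add])
print(agent.tool_names)   # tool names from the registry, e.g. ["add"]
print(agent.tool_config)  # aggregated ToolConfig built by the same registry
```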

def __del__(self) -> None:
"""Clean up resources when Agent is garbage collected.

Ensures proper shutdown of the thread pool executor if one exists.
"""
if self.thread_pool:
self.thread_pool.shutdown(wait=False)
logger.debug("thread pool executor shutdown complete")

def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.

This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
the conversation history, processes it through the model, executes any tool calls, and returns the final result.

Args:
prompt: User input as text or list of ContentBlock objects for multi-modal content.
prompt: The natural language prompt from the user.
**kwargs: Additional parameters to pass through the event loop.

Returns:
@@ -350,14 +379,14 @@ def execute() -> AgentResult:
future = executor.submit(execute)
return future.result()
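
Both entry points share the same event loop: `__call__` wraps it for synchronous callers, while `invoke_async` serves code already running inside an event loop. A usage sketch under those assumptions:

```python
import asyncio

from strands import Agent  # top-level export assumed

agent = Agent()

# Synchronous call: blocks until the loop completes, returns an AgentResult.
result = agent("Hello!")
print(result)

# Asynchronous call for callers that already have a running event loop.
async def main() -> None:
    result = await agent.invoke_async("Hello!")
    print(result)

asyncio.run(main())
```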

async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
async def invoke_async(self, prompt: str, **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.

This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
the conversation history, processes it through the model, executes any tool calls, and returns the final result.

Args:
prompt: User input as text or list of ContentBlock objects for multi-modal content.
prompt: The natural language prompt from the user.
**kwargs: Additional parameters to pass through the event loop.

Returns:
@@ -436,7 +465,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: Optional[
finally:
self._hooks.invoke_callbacks(EndRequestEvent(agent=self))

async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.

This method provides an asynchronous interface for streaming agent events, allowing
@@ -445,7 +474,7 @@
async environments.

Args:
prompt: User input as text or list of ContentBlock objects for multi-modal content.
prompt: The natural language prompt from the user.
**kwargs: Additional parameters to pass to the event loop.

Returns:
@@ -468,13 +497,10 @@
"""
callback_handler = kwargs.get("callback_handler", self.callback_handler)

content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
message: Message = {"role": "user", "content": content}

self._start_agent_trace_span(message)
self._start_agent_trace_span(prompt)

try:
events = self._run_loop(message, kwargs)
events = self._run_loop(prompt, kwargs)
async for event in events:
if "callback" in event:
callback_handler(**event["callback"])
@@ -490,22 +516,18 @@
self._end_agent_trace_span(error=e)
raise
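
For incremental output, `stream_async` yields event dictionaries as the loop runs. A hedged sketch; the `"data"` key for text chunks is assumed from the callback-handler convention rather than shown in this diff:

```python
import asyncio

from strands import Agent  # top-level export assumed

agent = Agent()

async def main() -> None:
    async for event in agent.stream_async("Tell me a story"):
        if "data" in event:
            # "data" events carry incremental text chunks (assumed key).
            print(event["data"], end="")

asyncio.run(main())
```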

async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Execute the agent's event loop with the given message and parameters.

Args:
message: The user message to add to the conversation.
kwargs: Additional parameters to pass to the event loop.

Yields:
Events from the event loop cycle.
"""
async def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Execute the agent's event loop with the given prompt and parameters."""
self._hooks.invoke_callbacks(StartRequestEvent(agent=self))

try:
# Extract key parameters
yield {"callback": {"init_event_loop": True, **kwargs}}

self.messages.append(message)
# Set up the user message with optional knowledge base retrieval
message_content: list[ContentBlock] = [{"text": prompt}]
new_message: Message = {"role": "user", "content": message_content}
self.messages.append(new_message)

# Execute the event loop cycle with retry logic for context limits
events = self._execute_event_loop_cycle(kwargs)
@@ -600,16 +622,16 @@ def _record_tool_execution(
messages.append(tool_result_msg)
messages.append(assistant_msg)

def _start_agent_trace_span(self, message: Message) -> None:
def _start_agent_trace_span(self, prompt: str) -> None:
"""Starts a trace span for the agent.

Args:
message: The user message.
prompt: The natural language prompt from the user.
"""
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None

self.trace_span = self.tracer.start_agent_span(
message=message,
prompt=prompt,
agent_name=self.name,
model_id=model_id,
tools=self.tool_names,