Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 45 additions & 77 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,18 @@
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
"""

import asyncio
import json
import logging
import os
import random
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union, cast
from uuid import uuid4
from typing import Any, AsyncIterator, Callable, Generator, Mapping, Optional, Type, TypeVar, Union, cast

from opentelemetry import trace
from pydantic import BaseModel

from ..event_loop.event_loop import event_loop_cycle
from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from ..handlers.tool_handler import AgentToolHandler
from ..models.bedrock import BedrockModel
from ..telemetry.metrics import EventLoopMetrics
Expand Down Expand Up @@ -210,7 +207,7 @@ def __init__(
self,
model: Union[Model, str, None] = None,
messages: Optional[Messages] = None,
tools: Optional[List[Union[str, Dict[str, str], Any]]] = None,
tools: Optional[list[Union[str, dict[str, str], Any]]] = None,
system_prompt: Optional[str] = None,
callback_handler: Optional[
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
Expand Down Expand Up @@ -282,7 +279,7 @@ def __init__(
self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager()

# Process trace attributes to ensure they're of compatible types
self.trace_attributes: Dict[str, AttributeValue] = {}
self.trace_attributes: dict[str, AttributeValue] = {}
if trace_attributes:
for k, v in trace_attributes.items():
if isinstance(v, (str, int, float, bool)) or (
Expand Down Expand Up @@ -339,7 +336,7 @@ def tool(self) -> ToolCaller:
return self.tool_caller

@property
def tool_names(self) -> List[str]:
def tool_names(self) -> list[str]:
"""Get a list of all registered tool names.

Returns:
Expand Down Expand Up @@ -384,19 +381,25 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
- metrics: Performance metrics from the event loop
- state: The final state of the event loop
"""
callback_handler = kwargs.get("callback_handler", self.callback_handler)

self._start_agent_trace_span(prompt)

try:
# Run the event loop and get the result
result = self._run_loop(prompt, kwargs)
events = self._run_loop(callback_handler, prompt, kwargs)
for event in events:
if "callback" in event:
callback_handler(**event["callback"])

stop_reason, message, metrics, state = event["stop"]
result = AgentResult(stop_reason, message, metrics, state)

self._end_agent_trace_span(response=result)

return result

except Exception as e:
self._end_agent_trace_span(error=e)

# Re-raise the exception to preserve original behavior
raise

def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
Expand Down Expand Up @@ -460,83 +463,56 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
yield event["data"]
```
"""
self._start_agent_trace_span(prompt)
callback_handler = kwargs.get("callback_handler", self.callback_handler)

_stop_event = uuid4()

queue = asyncio.Queue[Any]()
loop = asyncio.get_event_loop()

def enqueue(an_item: Any) -> None:
nonlocal queue
nonlocal loop
loop.call_soon_threadsafe(queue.put_nowait, an_item)

def queuing_callback_handler(**handler_kwargs: Any) -> None:
enqueue(handler_kwargs.copy())
self._start_agent_trace_span(prompt)

def target_callback() -> None:
nonlocal kwargs
try:
events = self._run_loop(callback_handler, prompt, kwargs)
for event in events:
if "callback" in event:
callback_handler(**event["callback"])
yield event["callback"]

try:
result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
self._end_agent_trace_span(response=result)
except Exception as e:
self._end_agent_trace_span(error=e)
enqueue(e)
finally:
enqueue(_stop_event)
stop_reason, message, metrics, state = event["stop"]
result = AgentResult(stop_reason, message, metrics, state)

thread = Thread(target=target_callback, daemon=True)
thread.start()
self._end_agent_trace_span(response=result)

try:
while True:
item = await queue.get()
if item == _stop_event:
break
if isinstance(item, Exception):
raise item
yield item
finally:
thread.join()
except Exception as e:
self._end_agent_trace_span(error=e)
raise

def _run_loop(
self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None
) -> AgentResult:
self, callback_handler: Callable[..., Any], prompt: str, kwargs: dict[str, Any]
) -> Generator[dict[str, Any], None, None]:
"""Execute the agent's event loop with the given prompt and parameters."""
try:
# If the call had a callback_handler passed in, then for this event_loop
# cycle we call both handlers as the callback_handler
invocation_callback_handler = (
CompositeCallbackHandler(self.callback_handler, supplementary_callback_handler)
if supplementary_callback_handler is not None
else self.callback_handler
)

# Extract key parameters
invocation_callback_handler(init_event_loop=True, **kwargs)
yield {"callback": {"init_event_loop": True, **kwargs}}

# Set up the user message with optional knowledge base retrieval
message_content: List[ContentBlock] = [{"text": prompt}]
message_content: list[ContentBlock] = [{"text": prompt}]
new_message: Message = {"role": "user", "content": message_content}
self.messages.append(new_message)

# Execute the event loop cycle with retry logic for context limits
return self._execute_event_loop_cycle(invocation_callback_handler, kwargs)
yield from self._execute_event_loop_cycle(callback_handler, kwargs)

finally:
self.conversation_manager.apply_management(self)

def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult:
def _execute_event_loop_cycle(
self, callback_handler: Callable[..., Any], kwargs: dict[str, Any]
) -> Generator[dict[str, Any], None, None]:
"""Execute the event loop cycle with retry logic for context window limits.

This internal method handles the execution of the event loop cycle and implements
retry logic for handling context window overflow exceptions by reducing the
conversation context and retrying.

Returns:
The result of the event loop cycle.
Yields:
Events of the loop cycle.
"""
# Extract parameters with fallbacks to instance values
system_prompt = kwargs.pop("system_prompt", self.system_prompt)
Expand All @@ -551,7 +527,7 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs

try:
# Execute the main event loop cycle
events = event_loop_cycle(
yield from event_loop_cycle(
model=model,
system_prompt=system_prompt,
messages=messages, # will be modified by event_loop_cycle
Expand All @@ -564,26 +540,18 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
event_loop_parent_span=self.trace_span,
**kwargs,
)
for event in events:
if "callback" in event:
callback_handler(**event["callback"])

stop_reason, message, metrics, state = event["stop"]

return AgentResult(stop_reason, message, metrics, state)

except ContextWindowOverflowException as e:
# Try reducing the context size and retrying

self.conversation_manager.reduce_context(self, e=e)
return self._execute_event_loop_cycle(callback_handler_override, kwargs)
yield from self._execute_event_loop_cycle(callback_handler_override, kwargs)

def _record_tool_execution(
self,
tool: Dict[str, Any],
tool_result: Dict[str, Any],
tool: dict[str, Any],
tool_result: dict[str, Any],
user_message_override: Optional[str],
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
) -> None:
"""Record a tool execution in the message history.

Expand Down Expand Up @@ -662,7 +630,7 @@ def _end_agent_trace_span(
error: Error to record as a trace attribute.
"""
if self.trace_span:
trace_attributes: Dict[str, Any] = {
trace_attributes: dict[str, Any] = {
"span": self.trace_span,
}

Expand Down
22 changes: 22 additions & 0 deletions tests-integ/test_agent_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

import strands


@pytest.fixture
def agent():
return strands.Agent()


@pytest.mark.asyncio
async def test_stream_async(agent):
stream = agent.stream_async("hello")

exp_message = ""
async for event in stream:
if "event" in event and "contentBlockDelta" in event["event"]:
exp_message += event["event"]["contentBlockDelta"]["delta"]["text"]

tru_message = agent.messages[-1]["content"][0]["text"]

assert tru_message == exp_message
Loading