diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 878cbefcd..402d95c6a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,21 +9,18 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ -import asyncio import json import logging import os import random from concurrent.futures import ThreadPoolExecutor -from threading import Thread -from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union, cast -from uuid import uuid4 +from typing import Any, AsyncIterator, Callable, Generator, Mapping, Optional, Type, TypeVar, Union, cast from opentelemetry import trace from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle -from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler +from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..handlers.tool_handler import AgentToolHandler from ..models.bedrock import BedrockModel from ..telemetry.metrics import EventLoopMetrics @@ -210,7 +207,7 @@ def __init__( self, model: Union[Model, str, None] = None, messages: Optional[Messages] = None, - tools: Optional[List[Union[str, Dict[str, str], Any]]] = None, + tools: Optional[list[Union[str, dict[str, str], Any]]] = None, system_prompt: Optional[str] = None, callback_handler: Optional[ Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] @@ -282,7 +279,7 @@ def __init__( self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager() # Process trace attributes to ensure they're of compatible types - self.trace_attributes: Dict[str, AttributeValue] = {} + self.trace_attributes: dict[str, AttributeValue] = {} if trace_attributes: for k, v in trace_attributes.items(): if isinstance(v, (str, int, float, bool)) or ( @@ -339,7 +336,7 @@ def tool(self) -> ToolCaller: return self.tool_caller @property - def tool_names(self) -> List[str]: + def tool_names(self) -> list[str]: """Get a list of all registered tool names. Returns: @@ -384,19 +381,25 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: - metrics: Performance metrics from the event loop - state: The final state of the event loop """ + callback_handler = kwargs.get("callback_handler", self.callback_handler) + self._start_agent_trace_span(prompt) try: - # Run the event loop and get the result - result = self._run_loop(prompt, kwargs) + events = self._run_loop(callback_handler, prompt, kwargs) + for event in events: + if "callback" in event: + callback_handler(**event["callback"]) + + stop_reason, message, metrics, state = event["stop"] + result = AgentResult(stop_reason, message, metrics, state) self._end_agent_trace_span(response=result) return result + except Exception as e: self._end_agent_trace_span(error=e) - - # Re-raise the exception to preserve original behavior raise def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T: @@ -460,83 +463,56 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: yield event["data"] ``` """ - self._start_agent_trace_span(prompt) + callback_handler = kwargs.get("callback_handler", self.callback_handler) - _stop_event = uuid4() - - queue = asyncio.Queue[Any]() - loop = asyncio.get_event_loop() - - def enqueue(an_item: Any) -> None: - nonlocal queue - nonlocal loop - loop.call_soon_threadsafe(queue.put_nowait, an_item) - - def queuing_callback_handler(**handler_kwargs: Any) -> None: - enqueue(handler_kwargs.copy()) + self._start_agent_trace_span(prompt) - def target_callback() -> None: - nonlocal kwargs + try: + events = self._run_loop(callback_handler, prompt, kwargs) + for event in events: + if "callback" in event: + callback_handler(**event["callback"]) + yield event["callback"] - try: - result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler) - self._end_agent_trace_span(response=result) - except Exception as e: - self._end_agent_trace_span(error=e) - enqueue(e) - finally: - enqueue(_stop_event) + stop_reason, message, metrics, state = event["stop"] + result = AgentResult(stop_reason, message, metrics, state) - thread = Thread(target=target_callback, daemon=True) - thread.start() + self._end_agent_trace_span(response=result) - try: - while True: - item = await queue.get() - if item == _stop_event: - break - if isinstance(item, Exception): - raise item - yield item - finally: - thread.join() + except Exception as e: + self._end_agent_trace_span(error=e) + raise def _run_loop( - self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None - ) -> AgentResult: + self, callback_handler: Callable[..., Any], prompt: str, kwargs: dict[str, Any] + ) -> Generator[dict[str, Any], None, None]: """Execute the agent's event loop with the given prompt and parameters.""" try: - # If the call had a callback_handler passed in, then for this event_loop - # cycle we call both handlers as the callback_handler - invocation_callback_handler = ( - CompositeCallbackHandler(self.callback_handler, supplementary_callback_handler) - if supplementary_callback_handler is not None - else self.callback_handler - ) - # Extract key parameters - invocation_callback_handler(init_event_loop=True, **kwargs) + yield {"callback": {"init_event_loop": True, **kwargs}} # Set up the user message with optional knowledge base retrieval - message_content: List[ContentBlock] = [{"text": prompt}] + message_content: list[ContentBlock] = [{"text": prompt}] new_message: Message = {"role": "user", "content": message_content} self.messages.append(new_message) # Execute the event loop cycle with retry logic for context limits - return self._execute_event_loop_cycle(invocation_callback_handler, kwargs) + yield from self._execute_event_loop_cycle(callback_handler, kwargs) finally: self.conversation_manager.apply_management(self) - def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult: + def _execute_event_loop_cycle( + self, callback_handler: Callable[..., Any], kwargs: dict[str, Any] + ) -> Generator[dict[str, Any], None, None]: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements retry logic for handling context window overflow exceptions by reducing the conversation context and retrying. - Returns: - The result of the event loop cycle. + Yields: + Events of the loop cycle. """ # Extract parameters with fallbacks to instance values system_prompt = kwargs.pop("system_prompt", self.system_prompt) @@ -551,7 +527,7 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs try: # Execute the main event loop cycle - events = event_loop_cycle( + yield from event_loop_cycle( model=model, system_prompt=system_prompt, messages=messages, # will be modified by event_loop_cycle @@ -564,26 +540,18 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs event_loop_parent_span=self.trace_span, **kwargs, ) - for event in events: - if "callback" in event: - callback_handler(**event["callback"]) - - stop_reason, message, metrics, state = event["stop"] - - return AgentResult(stop_reason, message, metrics, state) except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(self, e=e) - return self._execute_event_loop_cycle(callback_handler_override, kwargs) + yield from self._execute_event_loop_cycle(callback_handler_override, kwargs) def _record_tool_execution( self, - tool: Dict[str, Any], - tool_result: Dict[str, Any], + tool: dict[str, Any], + tool_result: dict[str, Any], user_message_override: Optional[str], - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], ) -> None: """Record a tool execution in the message history. @@ -662,7 +630,7 @@ def _end_agent_trace_span( error: Error to record as a trace attribute. """ if self.trace_span: - trace_attributes: Dict[str, Any] = { + trace_attributes: dict[str, Any] = { "span": self.trace_span, } diff --git a/tests-integ/test_agent_async.py b/tests-integ/test_agent_async.py new file mode 100644 index 000000000..597ba13f7 --- /dev/null +++ b/tests-integ/test_agent_async.py @@ -0,0 +1,22 @@ +import pytest + +import strands + + +@pytest.fixture +def agent(): + return strands.Agent() + + +@pytest.mark.asyncio +async def test_stream_async(agent): + stream = agent.stream_async("hello") + + exp_message = "" + async for event in stream: + if "event" in event and "contentBlockDelta" in event["event"]: + exp_message += event["event"]["contentBlockDelta"]["delta"]["text"] + + tru_message = agent.messages[-1]["content"][0]["text"] + + assert tru_message == exp_message diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2e7639796..6181d7a9e 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2,9 +2,7 @@ import importlib import os import textwrap -import threading import unittest.mock -from time import sleep import pytest from pydantic import BaseModel @@ -914,27 +912,30 @@ async def test_stream_async_returns_all_events(mock_event_loop_cycle): agent = Agent() # Define the side effect to simulate callback handler being called multiple times - def call_callback_handler(*args, **kwargs): - # Extract the callback handler from kwargs - callback_handler = kwargs.get("callback_handler") - # Call the callback handler with different data values - callback_handler(data="First chunk") - callback_handler(data="Second chunk") - callback_handler(data="Final chunk", complete=True) + def test_event_loop(*args, **kwargs): + yield {"callback": {"data": "First chunk"}} + yield {"callback": {"data": "Second chunk"}} + yield {"callback": {"data": "Final chunk", "complete": True}} + # Return expected values from event_loop_cycle yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} - mock_event_loop_cycle.side_effect = call_callback_handler + mock_event_loop_cycle.side_effect = test_event_loop + mock_callback = unittest.mock.Mock() - iterator = agent.stream_async("test message") - actual_events = [e async for e in iterator] + iterator = agent.stream_async("test message", callback_handler=mock_callback) - assert actual_events == [ - {"init_event_loop": True}, + tru_events = [e async for e in iterator] + exp_events = [ + {"init_event_loop": True, "callback_handler": mock_callback}, {"data": "First chunk"}, {"data": "Second chunk"}, {"complete": True, "data": "Final chunk"}, ] + assert tru_events == exp_events + + exp_calls = [unittest.mock.call(**event) for event in exp_events] + mock_callback.assert_has_calls(exp_calls) @pytest.mark.asyncio @@ -982,115 +983,6 @@ async def test_stream_async_raises_exceptions(mock_event_loop_cycle): await anext(iterator) -@pytest.mark.asyncio -async def test_stream_async_can_be_invoked_twice(mock_event_loop_cycle): - """Test that run can be invoked twice with different agents.""" - # Define different responses for the first and second invocations - exp_call_1 = [{"data": "First call - event 1"}, {"data": "First call - event 2", "complete": True}] - exp_call_2 = [{"data": "Second call - event 1"}, {"data": "Second call - event 2", "complete": True}] - - # Set up the mock to handle two different calls - call_count = 0 - - def mock_event_loop_call(**kwargs): - nonlocal call_count - # Extract the callback handler from kwargs - callback_handler = kwargs.get("callback_handler") - events_to_use = exp_call_1 if call_count == 0 else exp_call_2 - call_count += 1 - - for event in events_to_use: - callback_handler(**event) - - # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} - - mock_event_loop_cycle.side_effect = mock_event_loop_call - - agent1 = Agent() - - iter_1 = agent1.stream_async("First prompt") - act_call_1 = [e async for e in iter_1] - assert act_call_1 == [{"init_event_loop": True}, *exp_call_1] - - iter_2 = agent1.stream_async("Second prompt") - act_call_2 = [e async for e in iter_2] - assert act_call_2 == [{"init_event_loop": True}, *exp_call_2] - - # Verify the mock was called twice - assert call_count == 2 - assert mock_event_loop_cycle.call_count == 2 - - # Verify the correct arguments were passed to event_loop_cycle - # First call - args1, kwargs1 = mock_event_loop_cycle.call_args_list[0] - assert kwargs1.get("model") == agent1.model - assert kwargs1.get("system_prompt") == agent1.system_prompt - assert kwargs1.get("messages") == agent1.messages - assert kwargs1.get("tool_config") == agent1.tool_config - assert "callback_handler" in kwargs1 - - # Second call - args2, kwargs2 = mock_event_loop_cycle.call_args_list[1] - assert kwargs2.get("model") == agent1.model - assert kwargs2.get("system_prompt") == agent1.system_prompt - assert kwargs2.get("messages") == agent1.messages - assert kwargs2.get("tool_config") == agent1.tool_config - assert "callback_handler" in kwargs2 - - -@pytest.mark.asyncio -async def test_run_non_blocking_behavior(mock_event_loop_cycle): - """Test that when one thread is blocked in run, other threads can continue execution.""" - - # This event will be used to signal when the first thread has started - unblock_background_thread = threading.Event() - is_blocked = False - - # Define a side effect that blocks until explicitly allowed to continue - def blocking_call(**kwargs): - nonlocal is_blocked - # Extract the callback handler from kwargs - callback_handler = kwargs.get("callback_handler") - callback_handler(data="First event") - is_blocked = True - unblock_background_thread.wait(timeout=5.0) - is_blocked = False - callback_handler(data="Last event", complete=True) - # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} - - mock_event_loop_cycle.side_effect = blocking_call - - # Create and start the background thread - agent = Agent() - iterator = agent.stream_async("This will block") - - # Ensure it emits the first event - assert await anext(iterator) == {"init_event_loop": True} - assert await anext(iterator) == {"data": "First event"} - - retry_count = 0 - while not is_blocked and retry_count < 10: - sleep(1) - retry_count += 1 - assert is_blocked - - # Ensure it emits the next event - unblock_background_thread.set() - assert await anext(iterator) == {"data": "Last event", "complete": True} - - retry_count = 0 - while is_blocked and retry_count < 10: - sleep(1) - retry_count += 1 - assert not is_blocked - - # Ensure the iterator is exhausted - remaining = [it async for it in iterator] - assert len(remaining) == 0 - - def test_agent_init_with_trace_attributes(): """Test that trace attributes are properly initialized in the Agent.""" # Test with valid trace attributes