diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py
index af8c4e1e1..8e8f121dd 100644
--- a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py
+++ b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py
@@ -1,15 +1,9 @@
 """Event loop management for bidirectional streaming."""
 
 from .bidirectional_event_loop import (
-    BidirectionalConnection,
-    bidirectional_event_loop_cycle,
-    start_bidirectional_connection,
-    stop_bidirectional_connection,
+    BidirectionalAgentLoop,
 )
 
 __all__ = [
-    "BidirectionalConnection",
-    "start_bidirectional_connection",
-    "stop_bidirectional_connection",
-    "bidirectional_event_loop_cycle",
+    "BidirectionalAgentLoop",
 ]
diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py
index 38d92aea8..1de91b4c2 100644
--- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py
+++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py
@@ -1,40 +1,29 @@
-"""Bidirectional session management for concurrent streaming conversations.
-
-Manages bidirectional communication sessions with concurrent processing of model events,
-tool execution, and audio processing. Provides coordination between background tasks
-while maintaining a simple interface for agent interaction.
-
-Features:
-- Concurrent task management for model events and tool execution
-- Interruption handling with audio buffer clearing
-- Tool execution with cancellation support
-- Session lifecycle management
-"""
+"""Class-based event loop for real-time bidirectional streaming with concurrent tool execution."""
 
 import asyncio
 import logging
-import traceback
 import uuid
+from typing import TYPE_CHECKING, Dict, List
 
-from ....tools._validator import validate_and_prepare_tools
+from ....hooks import MessageAddedEvent
 from ....telemetry.metrics import Trace
+from ....tools._validator import validate_and_prepare_tools
 from ....types._events import ToolResultEvent, ToolStreamEvent
 from ....types.content import Message
 from ....types.tools import ToolResult, ToolUse
 from ..models.bidirectional_model import BidirectionalModel
 
+if TYPE_CHECKING:
+    from ..agent import BidirectionalAgent
+
 logger = logging.getLogger(__name__)
 
-# Session constants
-TOOL_QUEUE_TIMEOUT = 0.5
-SUPERVISION_INTERVAL = 0.1
 
+class BidirectionalAgentLoop:
+    """Agent loop coordinator for bidirectional streaming sessions.
 
-class BidirectionalConnection:
-    """Session wrapper for bidirectional communication with concurrent task management.
-
-    Coordinates background tasks for model event processing, tool execution, and audio
-    handling while providing a simple interface for agent interactions.
+    Processes model events in a background task and schedules each tool call with
+    asyncio.create_task() so tools run concurrently without blocking event processing.
""" def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> None: @@ -48,433 +37,274 @@ def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> No self.agent = agent self.active = True - # Background processing coordination - self.background_tasks = [] - self.tool_queue = asyncio.Queue() - self.audio_output_queue = asyncio.Queue() - - # Task management for cleanup - self.pending_tool_tasks: dict[str, asyncio.Task] = {} + # Task tracking + self.background_tasks: List[asyncio.Task] = [] + self.pending_tool_tasks: Dict[str, asyncio.Task] = {} - # Interruption handling (model-agnostic) + # Synchronization primitives self.interrupted = False self.interruption_lock = asyncio.Lock() - - # Tool execution tracking - self.tool_count = 0 - - -async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: - """Initialize bidirectional session with conycurrent background tasks. - - Creates a model-specific session and starts background tasks for processing - model events, executing tools, and managing the session lifecycle. - - Args: - agent: BidirectionalAgent instance. - - Returns: - BidirectionalConnection: Active session with background tasks running. - """ - logger.debug("Starting bidirectional session - initializing model connection") - - # Connect to model - await agent.model.connect( - system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages - ) - - # Create connection wrapper for background processing - session = BidirectionalConnection(model=agent.model, agent=agent) - - # Start concurrent background processors IMMEDIATELY after session creation - # This is critical - Nova Sonic needs response processing during initialization - logger.debug("Starting background processors for concurrent processing") - session.background_tasks = [ - asyncio.create_task(_process_model_events(session)), # Handle model responses - asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently - ] + self.conversation_lock = asyncio.Lock() - # Start main coordination cycle - session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) + # Output queue and metrics + self.event_output_queue = asyncio.Queue() + self.tool_count = 0 - logger.debug("Session ready with %d background tasks", len(session.background_tasks)) - return session + logger.debug("BidirectionalAgentLoop initialized") + async def start(self) -> None: + """Start background task for model event processing.""" + logger.debug("Starting bidirectional agent loop") -async def stop_bidirectional_connection(session: BidirectionalConnection) -> None: - """End session and cleanup resources including background tasks. + self.background_tasks = [asyncio.create_task(self._process_model_events())] - Args: - session: BidirectionalConnection to cleanup. 
- """ - if not session.active: - return + logger.debug("Agent loop started") - logger.debug("Session cleanup starting") - session.active = False + async def stop(self) -> None: + """Gracefully shutdown and cleanup all resources.""" + if not self.active: + return - # Cancel pending tool tasks - for _, task in session.pending_tool_tasks.items(): - if not task.done(): - task.cancel() + logger.debug("Stopping bidirectional agent loop") + self.active = False - # Cancel background tasks - for task in session.background_tasks: - if not task.done(): - task.cancel() + # Cancel pending tool tasks + for task in self.pending_tool_tasks.values(): + if not task.done(): + task.cancel() - # Cancel main cycle task - if hasattr(session, "main_cycle_task") and not session.main_cycle_task.done(): - session.main_cycle_task.cancel() + # Cancel background task + for task in self.background_tasks: + if not task.done(): + task.cancel() - # Wait for tasks to complete - all_tasks = session.background_tasks + list(session.pending_tool_tasks.values()) - if hasattr(session, "main_cycle_task"): - all_tasks.append(session.main_cycle_task) + # Wait for cancellations + all_tasks = list(self.pending_tool_tasks.values()) + self.background_tasks + if all_tasks: + await asyncio.gather(*all_tasks, return_exceptions=True) - if all_tasks: - await asyncio.gather(*all_tasks, return_exceptions=True) + # Close model session + try: + await self.model_session.close() + except Exception as e: + logger.warning("Error closing model session: %s", e) - # Close model connection - await session.model.close() - logger.debug("Connection closed") + logger.debug("Agent loop stopped - tools executed: %d", self.tool_count) + def schedule_tool_execution(self, tool_use: ToolUse) -> None: + """Create asyncio task for tool execution and add to pending tasks tracking.""" + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") -async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: - """Main event loop coordinator that runs continuously during the session. + current_tool_number = self.tool_count + 1 + self.tool_count = current_tool_number + print(f"\nTool #{current_tool_number}: {tool_name}") - Monitors background tasks, manages session state, and handles session lifecycle. - Provides supervision for concurrent model event processing and tool execution. + logger.debug("Scheduling tool execution: %s (id: %s)", tool_name, tool_id) - Args: - session: BidirectionalConnection to coordinate. 
- """ - while session.active: - try: - # Check if background processors are still running - if all(task.done() for task in session.background_tasks): - logger.debug("Session end - all processors completed") - session.active = False - break - - # Check for failed background tasks - for i, task in enumerate(session.background_tasks): - if task.done() and not task.cancelled(): - exception = task.exception() - if exception: - logger.error("Session error in processor %d: %s", i, str(exception)) - session.active = False - raise exception - - # Brief pause before next supervision check - await asyncio.sleep(SUPERVISION_INTERVAL) + # Create task with UUID tracking + task_id = str(uuid.uuid4()) + task = asyncio.create_task(self._execute_tool(tool_use)) + self.pending_tool_tasks[task_id] = task - except asyncio.CancelledError: - break - except Exception as e: - logger.error("Event loop error: %s", str(e)) - session.active = False - raise + def cleanup_task(completed_task: asyncio.Task) -> None: + self.pending_tool_tasks.pop(task_id, None) + if completed_task.cancelled(): + logger.debug("Tool task cancelled: %s", task_id) + elif completed_task.exception(): + logger.error("Tool task error: %s - %s", task_id, completed_task.exception()) + else: + logger.debug("Tool task completed: %s", task_id) + task.add_done_callback(cleanup_task) -async def _handle_interruption(session: BidirectionalConnection) -> None: - """Handle interruption detection with task cancellation and audio buffer clearing. + async def handle_interruption(self) -> None: + """Execute atomic interruption handling with race condition prevention. - Cancels pending tool tasks and clears audio output queues to ensure responsive - interruption handling during conversations. Protected by async lock to prevent - concurrent execution and race conditions. + Clears audio buffers and checks tool execution status. + Does not cancel running tools - only clears audio output queues. + """ + async with self.interruption_lock: + if self.interrupted: + logger.debug("Interruption already in progress") + return - Args: - session: BidirectionalConnection to handle interruption for. 
- """ - async with session.interruption_lock: - # If already interrupted, skip duplicate processing - if session.interrupted: - logger.debug("Interruption already in progress") - return + logger.debug("Interruption detected") + self.interrupted = True - logger.debug("Interruption detected") - session.interrupted = True + # Check if tools are currently executing + active_tool_tasks = [task for task in self.pending_tool_tasks.values() if not task.done()] - # Cancel all pending tool execution tasks - cancelled_tools = 0 - for _task_id, task in list(session.pending_tool_tasks.items()): - if not task.done(): - task.cancel() - cancelled_tools += 1 - logger.debug("Tool task cancelled: %s", _task_id) + if active_tool_tasks: + logger.debug("Tools are protected - %d tools currently executing", len(active_tool_tasks)) + else: + logger.debug("No active tools - full interruption handling") - if cancelled_tools > 0: - logger.debug("Tool tasks cancelled: %d", cancelled_tools) + # Clear output queue + cleared_count = 0 + while True: + try: + self.event_output_queue.get_nowait() + cleared_count += 1 + except asyncio.QueueEmpty: + break - # Clear all queued audio output events - cleared_count = 0 - while True: + # Filter audio events from agent queue, preserve others + temp_events = [] try: - session.audio_output_queue.get_nowait() - cleared_count += 1 + while True: + event = self.agent._output_queue.get_nowait() + if event.get("audioOutput"): + cleared_count += 1 + else: + temp_events.append(event) except asyncio.QueueEmpty: - break - - # Also clear the agent's audio output queue - audio_cleared = 0 - # Create a temporary list to hold non-audio events - temp_events = [] - try: - while True: - event = session.agent._output_queue.get_nowait() - if event.get("audioOutput"): - audio_cleared += 1 - else: - # Keep non-audio events - temp_events.append(event) - except asyncio.QueueEmpty: - pass - - # Put back non-audio events - for event in temp_events: - session.agent._output_queue.put_nowait(event) - - if audio_cleared > 0: - logger.debug("Agent audio queue cleared: %d events", audio_cleared) - - if cleared_count > 0: - logger.debug("Session audio queue cleared: %d events", cleared_count) + pass - # Reset interruption flag after clearing (automatic recovery) - session.interrupted = False - logger.debug("Interruption handled - tools cancelled: %d, audio cleared: %d", cancelled_tools, cleared_count) + # Restore non-audio events + for event in temp_events: + self.agent._output_queue.put_nowait(event) + self.interrupted = False -async def _process_model_events(session: BidirectionalConnection) -> None: - """Process model events and convert them to Strands format. + if active_tool_tasks: + logger.debug("Interruption handled (tools protected) - audio cleared: %d", cleared_count) + else: + logger.debug("Interruption handled (full) - audio cleared: %d", cleared_count) - Background task that handles all model responses, converts provider-specific - events to standardized formats, and manages interruption detection. + async def _process_model_events(self) -> None: + """Process incoming provider event stream and dispatch events to handlers.""" + logger.debug("Agent loop processor started") - Args: - session: BidirectionalConnection containing model. 
- """ - logger.debug("Model events processor started") - try: - async for provider_event in session.model.receive(): - if not session.active: - break - - # Basic validation - skip invalid events - if not isinstance(provider_event, dict): - continue - - strands_event = provider_event - - # Handle interruption detection (provider converts raw patterns to interruptionDetected) - if strands_event.get("interruptionDetected"): - logger.debug("Interruption forwarded") - await _handle_interruption(session) - # Forward interruption event to agent for application-level handling - await session.agent._output_queue.put(strands_event) - continue - - # Queue tool requests for concurrent execution - if strands_event.get("toolUse"): - tool_name = strands_event["toolUse"].get("name") - logger.debug("Tool usage detected: %s", tool_name) - await session.tool_queue.put(strands_event["toolUse"]) - continue - - # Send output events to Agent for receive() method - if strands_event.get("audioOutput") or strands_event.get("textOutput"): - await session.agent._output_queue.put(strands_event) - - # Update Agent conversation history using existing patterns - if strands_event.get("messageStop"): - logger.debug("Message added to history") - session.agent.messages.append(strands_event["messageStop"]["message"]) - - # Handle user audio transcripts - add to message history - if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": - user_transcript = strands_event["textOutput"]["text"] - if user_transcript.strip(): # Only add non-empty transcripts - user_message = {"role": "user", "content": user_transcript} - session.agent.messages.append(user_message) - logger.debug("User transcript added to history") - - except Exception as e: - logger.error("Model events error: %s", str(e)) - traceback.print_exc() - finally: - logger.debug("Model events processor stopped") - - -async def _process_tool_execution(session: BidirectionalConnection) -> None: - """Execute tools concurrently with interruption support. - - Background task that manages tool execution without blocking model event - processing or user interaction. Uses proper asyncio cancellation for - interruption handling rather than manual state checks. - - Args: - session: BidirectionalConnection containing tool queue. 
- """ - logger.debug("Tool execution processor started") - while session.active: try: - tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) - tool_name = tool_use.get("name") - tool_id = tool_use.get("toolUseId") - - session.tool_count += 1 - print(f"\nTool #{session.tool_count}: {tool_name}") - - logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id) - - task_id = str(uuid.uuid4()) - task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) - session.pending_tool_tasks[task_id] = task - - def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: - try: - # Remove from pending tasks - if task_id in session.pending_tool_tasks: - del session.pending_tool_tasks[task_id] - - # Log completion status - if completed_task.cancelled(): - logger.debug("Tool task cancelled: %s", task_id) - elif completed_task.exception(): - logger.error("Tool task error: %s - %s", task_id, str(completed_task.exception())) - else: - logger.debug("Tool task completed: %s", task_id) - except Exception as e: - logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) - - task.add_done_callback(cleanup_task) + async for provider_event in self.model_session.receive_events(): + if not self.active: + break + + if not isinstance(provider_event, dict): + continue + + strands_event = provider_event + + # Handle interruptions + if strands_event.get("interruptionDetected"): + logger.debug("Interruption detected from model") + await self.handle_interruption() + await self.agent._output_queue.put(strands_event) + continue + + # Schedule tool execution + if strands_event.get("toolUse"): + tool_name = strands_event["toolUse"].get("name") + logger.debug("Tool request received: %s", tool_name) + self.schedule_tool_execution(strands_event["toolUse"]) + continue + + # Route audio to both queues + if strands_event.get("audioOutput"): + await self.event_output_queue.put(strands_event) + await self.agent._output_queue.put(strands_event) + continue + + # Forward text output + if strands_event.get("textOutput"): + await self.agent._output_queue.put(strands_event) + + # Update conversation history + if strands_event.get("messageStop"): + logger.debug("Adding message to conversation history") + async with self.conversation_lock: + self.agent.messages.append(strands_event["messageStop"]["message"]) + + # Handle user transcripts + if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": + user_transcript = strands_event["textOutput"]["text"] + if user_transcript.strip(): + user_message = {"role": "user", "content": user_transcript} + async with self.conversation_lock: + self.agent.messages.append(user_message) + logger.debug("User transcript added to history") - except asyncio.TimeoutError: - if not session.active: - break - # Remove completed tasks from tracking - completed_tasks = [task_id for task_id, task in session.pending_tool_tasks.items() if task.done()] - for task_id in completed_tasks: - if task_id in session.pending_tool_tasks: - del session.pending_tool_tasks[task_id] - - if completed_tasks: - logger.debug("Periodic task cleanup: %d tasks", len(completed_tasks)) - - continue except Exception as e: - logger.error("Tool execution error: %s", str(e)) - if not session.active: - break + logger.error("Agent loop processor failed: %s", e) + self.active = False + finally: + logger.debug("Agent loop processor stopped") - logger.debug("Tool execution processor stopped") + async def _execute_tool(self, tool_use: 
ToolUse) -> None: + """Execute tool using Strands validation and execution pipeline.""" + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) + try: + # Prepare for Strands validation system + tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + tool_uses: list[ToolUse] = [] + tool_results: list[ToolResult] = [] + invalid_tool_use_ids: list[str] = [] + + # Validate tools + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) + valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] + + if not valid_tool_uses: + logger.warning("No valid tools after validation: %s", tool_name) + return + + # Execute with agent context + invocation_state = { + "agent": self.agent, + "model": self.agent.model, + "messages": self.agent.messages, + "system_prompt": self.agent.system_prompt, + } + cycle_trace = Trace("Bidirectional Tool Execution") + tool_events = self.agent.tool_executor._execute( + self.agent, valid_tool_uses, tool_results, cycle_trace, None, invocation_state + ) + + # Process tool event stream + async for tool_event in tool_events: + if isinstance(tool_event, ToolResultEvent): + tool_result = tool_event.tool_result + tool_use_id = tool_result.get("toolUseId") + await self.model_session.send_tool_result(tool_use_id, tool_result) + logger.debug("Tool result sent: %s", tool_use_id) + elif isinstance(tool_event, ToolStreamEvent): + logger.debug("Tool stream event: %s", tool_event) + + # Update conversation history using conversation lock + if tool_results: + tool_result_message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + + async with self.conversation_lock: + self.agent.messages.append(tool_result_message) + self.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=self.agent, message=tool_result_message)) + logger.debug("Tool result message added to history: %s", tool_name) + + logger.debug("Tool execution completed: %s", tool_name) + except asyncio.CancelledError: + logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) + raise + except Exception as e: + logger.error("Tool execution error: %s - %s", tool_name, e) -async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: - """Execute tool using the complete Strands tool execution system. - - Uses proper Strands ToolExecutor system with validation, error handling, - and event streaming. - - Args: - session: BidirectionalConnection for context. - tool_use: Tool use event to execute. 
- """ - tool_name = tool_use.get("name") - tool_id = tool_use.get("toolUseId") - - logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) - - try: - # Create message structure for validation - tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - - # Use Strands validation system - tool_uses: list[ToolUse] = [] - tool_results: list[ToolResult] = [] - invalid_tool_use_ids: list[str] = [] - - validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - - # Filter valid tools - valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - - if not valid_tool_uses: - logger.warning("No valid tools after validation: %s", tool_name) - return - - # Create invocation state for tool execution - invocation_state = { - "agent": session.agent, - "model": session.agent.model, - "messages": session.agent.messages, - "system_prompt": session.agent.system_prompt, - } - - # Create cycle trace and span - cycle_trace = Trace("Bidirectional Tool Execution") - cycle_span = None - - tool_events = session.agent.tool_executor._execute( - session.agent, - valid_tool_uses, - tool_results, - cycle_trace, - cycle_span, - invocation_state - ) - - # Process tool events and send results to provider - async for tool_event in tool_events: - if isinstance(tool_event, ToolResultEvent): - tool_result = tool_event.tool_result - tool_use_id = tool_result.get("toolUseId") - - # Send result through send() method - await session.model.send(tool_result) - logger.debug("Tool result sent: %s", tool_use_id) - - # Handle streaming events if needed later - elif isinstance(tool_event, ToolStreamEvent): - logger.debug("Tool stream event: %s", tool_event) - pass - - # Add tool result message to conversation history - if tool_results: - from ....hooks import MessageAddedEvent - - tool_result_message: Message = { - "role": "user", - "content": [{"toolResult": result} for result in tool_results], + # Send error result to provider + error_result: ToolResult = { + "toolUseId": tool_id, + "status": "error", + "content": [{"text": f"Error: {str(e)}"}], } - - session.agent.messages.append(tool_result_message) - session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message)) - logger.debug("Tool result message added to history: %s", tool_name) - - logger.debug("Tool execution completed: %s", tool_name) - - except asyncio.CancelledError: - logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) - raise - except Exception as e: - logger.error("Tool execution error: %s - %s", tool_name, str(e)) - - # Send error result - error_result: ToolResult = { - "toolUseId": tool_id, - "status": "error", - "content": [{"text": f"Error: {str(e)}"}] - } - try: - await session.model.send(error_result) - logger.debug("Error result sent: %s", tool_id) - except Exception as send_error: - logger.error("Failed to send error result: %s - %s", tool_id, str(send_error)) - raise # Propagate exception since this is experimental code - + try: + await self.model_session.send_tool_result(tool_id, error_result) + logger.debug("Error result sent: %s", tool_id) + except Exception: + logger.error("Failed to send error result: %s", tool_id)