From 7dff9763713390eadfbf2e5027ecc1eef708cc19 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 21 Oct 2025 11:23:09 -0400 Subject: [PATCH 1/5] (feat): Improve bidi event loop --- .../event_loop/bidirectional_event_loop.py | 743 ++++++++---------- 1 file changed, 316 insertions(+), 427 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 69f5d759d..cd29c9e0e 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -1,481 +1,370 @@ -"""Bidirectional session management for concurrent streaming conversations. - -Manages bidirectional communication sessions with concurrent processing of model events, -tool execution, and audio processing. Provides coordination between background tasks -while maintaining a simple interface for agent interaction. - -Features: -- Concurrent task management for model events and tool execution -- Interruption handling with audio buffer clearing -- Tool execution with cancellation support -- Session lifecycle management -""" +"""Class-based event loop for real-time bidirectional streaming with concurrent tool execution.""" import asyncio import logging import traceback import uuid +from typing import TYPE_CHECKING, Optional, Dict, List from ....tools._validator import validate_and_prepare_tools from ....telemetry.metrics import Trace from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse +from ....hooks import MessageAddedEvent from ..models.bidirectional_model import BidirectionalModelSession +if TYPE_CHECKING: + from ..agent import BidirectionalAgent logger = logging.getLogger(__name__) -# Session constants -TOOL_QUEUE_TIMEOUT = 0.5 -SUPERVISION_INTERVAL = 0.1 - -class BidirectionalConnection: - """Session wrapper for bidirectional communication with concurrent task management. - - Coordinates background tasks for model event processing, tool execution, and audio - handling while providing a simple interface for agent interactions. +class BidirectionalEventLoop: + """Event loop coordinator for bidirectional streaming sessions. + + Manages concurrent background tasks for model event processing and session supervision. + Tool execution uses immediate asyncio.Task creation (0ms scheduling) rather than polling. + Provides atomic interruption handling and race condition prevention. """ - def __init__(self, model_session: BidirectionalModelSession, agent: "BidirectionalAgent") -> None: - """Initialize session with model session and agent reference. - - Args: - model_session: Provider-specific bidirectional model session. - agent: BidirectionalAgent instance for tool registry access. 
- """ + def __init__(self, model_session: BidirectionalModelSession, agent: "BidirectionalAgent"): + """Initialize event loop with model session and agent dependencies.""" self.model_session = model_session self.agent = agent self.active = True - - # Background processing coordination - self.background_tasks = [] - self.tool_queue = asyncio.Queue() - self.audio_output_queue = asyncio.Queue() - - # Task management for cleanup - self.pending_tool_tasks: dict[str, asyncio.Task] = {} - - # Interruption handling (model-agnostic) + + # Task tracking + self.background_tasks: List[asyncio.Task] = [] + self.pending_tool_tasks: Dict[str, asyncio.Task] = {} + + # Synchronization primitives self.interrupted = False self.interruption_lock = asyncio.Lock() + self.conversation_lock = asyncio.Lock() # Race condition prevention - # Tool execution tracking + # Audio and metrics + self.audio_output_queue = asyncio.Queue() self.tool_count = 0 + + logger.debug("BidirectionalEventLoop initialized") + async def start(self) -> None: + """Start background tasks for model event processing and session supervision.""" + logger.debug("Starting bidirectional event loop") + + self.background_tasks = [ + asyncio.create_task(self._process_model_events()), + asyncio.create_task(self._supervise_session()), + ] + + logger.debug("Event loop started with %d background tasks", len(self.background_tasks)) -async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: - """Initialize bidirectional session with conycurrent background tasks. - - Creates a model-specific session and starts background tasks for processing - model events, executing tools, and managing the session lifecycle. - - Args: - agent: BidirectionalAgent instance. - - Returns: - BidirectionalConnection: Active session with background tasks running. - """ - logger.debug("Starting bidirectional session - initializing model session") - - # Create provider-specific session - model_session = await agent.model.create_bidirectional_connection( - system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages - ) - - # Create session wrapper for background processing - session = BidirectionalConnection(model_session=model_session, agent=agent) - - # Start concurrent background processors IMMEDIATELY after session creation - # This is critical - Nova Sonic needs response processing during initialization - logger.debug("Starting background processors for concurrent processing") - session.background_tasks = [ - asyncio.create_task(_process_model_events(session)), # Handle model responses - asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently - ] - - # Start main coordination cycle - session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) - - logger.debug("Session ready with %d background tasks", len(session.background_tasks)) - return session - - -async def stop_bidirectional_connection(session: BidirectionalConnection) -> None: - """End session and cleanup resources including background tasks. - - Args: - session: BidirectionalConnection to cleanup. 
- """ - if not session.active: - return - - logger.debug("Session cleanup starting") - session.active = False - - # Cancel pending tool tasks - for _, task in session.pending_tool_tasks.items(): - if not task.done(): - task.cancel() - - # Cancel background tasks - for task in session.background_tasks: - if not task.done(): - task.cancel() - - # Cancel main cycle task - if hasattr(session, "main_cycle_task") and not session.main_cycle_task.done(): - session.main_cycle_task.cancel() - - # Wait for tasks to complete - all_tasks = session.background_tasks + list(session.pending_tool_tasks.values()) - if hasattr(session, "main_cycle_task"): - all_tasks.append(session.main_cycle_task) - - if all_tasks: - await asyncio.gather(*all_tasks, return_exceptions=True) - - # Close model session - await session.model_session.close() - logger.debug("Session closed") - - -async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: - """Main event loop coordinator that runs continuously during the session. - - Monitors background tasks, manages session state, and handles session lifecycle. - Provides supervision for concurrent model event processing and tool execution. - - Args: - session: BidirectionalConnection to coordinate. - """ - while session.active: - try: - # Check if background processors are still running - if all(task.done() for task in session.background_tasks): - logger.debug("Session end - all processors completed") - session.active = False - break - - # Check for failed background tasks - for i, task in enumerate(session.background_tasks): - if task.done() and not task.cancelled(): - exception = task.exception() - if exception: - logger.error("Session error in processor %d: %s", i, str(exception)) - session.active = False - raise exception - - # Brief pause before next supervision check - await asyncio.sleep(SUPERVISION_INTERVAL) - - except asyncio.CancelledError: - break - except Exception as e: - logger.error("Event loop error: %s", str(e)) - session.active = False - raise - - -async def _handle_interruption(session: BidirectionalConnection) -> None: - """Handle interruption detection with task cancellation and audio buffer clearing. - - Cancels pending tool tasks and clears audio output queues to ensure responsive - interruption handling during conversations. Protected by async lock to prevent - concurrent execution and race conditions. - - Args: - session: BidirectionalConnection to handle interruption for. 
- """ - async with session.interruption_lock: - # If already interrupted, skip duplicate processing - if session.interrupted: - logger.debug("Interruption already in progress") + async def stop(self) -> None: + """Gracefully shutdown and cleanup all resources.""" + if not self.active: return - - logger.debug("Interruption detected") - session.interrupted = True - - # Cancel all pending tool execution tasks - cancelled_tools = 0 - for _task_id, task in list(session.pending_tool_tasks.items()): + + logger.debug("Stopping bidirectional event loop") + self.active = False + + # Cancel all tasks + for task in self.pending_tool_tasks.values(): if not task.done(): task.cancel() - cancelled_tools += 1 - logger.debug("Tool task cancelled: %s", _task_id) - - if cancelled_tools > 0: - logger.debug("Tool tasks cancelled: %d", cancelled_tools) - - # Clear all queued audio output events - cleared_count = 0 - while True: - try: - session.audio_output_queue.get_nowait() - cleared_count += 1 - except asyncio.QueueEmpty: - break - - # Also clear the agent's audio output queue - audio_cleared = 0 - # Create a temporary list to hold non-audio events - temp_events = [] + + for task in self.background_tasks: + if not task.done(): + task.cancel() + + # Wait for cancellations + all_tasks = list(self.pending_tool_tasks.values()) + self.background_tasks + if all_tasks: + await asyncio.gather(*all_tasks, return_exceptions=True) + + # Close model session try: - while True: - event = session.agent._output_queue.get_nowait() - if event.get("audioOutput"): - audio_cleared += 1 - else: - # Keep non-audio events - temp_events.append(event) - except asyncio.QueueEmpty: - pass - - # Put back non-audio events - for event in temp_events: - session.agent._output_queue.put_nowait(event) - - if audio_cleared > 0: - logger.debug("Agent audio queue cleared: %d events", audio_cleared) - - if cleared_count > 0: - logger.debug("Session audio queue cleared: %d events", cleared_count) - - # Reset interruption flag after clearing (automatic recovery) - session.interrupted = False - logger.debug("Interruption handled - tools cancelled: %d, audio cleared: %d", cancelled_tools, cleared_count) - - -async def _process_model_events(session: BidirectionalConnection) -> None: - """Process model events and convert them to Strands format. - - Background task that handles all model responses, converts provider-specific - events to standardized formats, and manages interruption detection. - - Args: - session: BidirectionalConnection containing model session. 
- """ - logger.debug("Model events processor started") - try: - async for provider_event in session.model_session.receive_events(): - if not session.active: - break - - # Basic validation - skip invalid events - if not isinstance(provider_event, dict): - continue - - strands_event = provider_event - - # Handle interruption detection (provider converts raw patterns to interruptionDetected) - if strands_event.get("interruptionDetected"): - logger.debug("Interruption forwarded") - await _handle_interruption(session) - # Forward interruption event to agent for application-level handling - await session.agent._output_queue.put(strands_event) - continue - - # Queue tool requests for concurrent execution - if strands_event.get("toolUse"): - tool_name = strands_event["toolUse"].get("name") - logger.debug("Tool usage detected: %s", tool_name) - await session.tool_queue.put(strands_event["toolUse"]) - continue - - # Send output events to Agent for receive() method - if strands_event.get("audioOutput") or strands_event.get("textOutput"): - await session.agent._output_queue.put(strands_event) - - # Update Agent conversation history using existing patterns - if strands_event.get("messageStop"): - logger.debug("Message added to history") - session.agent.messages.append(strands_event["messageStop"]["message"]) + await self.model_session.close() + except Exception as e: + logger.warning("Error closing model session: %s", e) - # Handle user audio transcripts - add to message history - if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": - user_transcript = strands_event["textOutput"]["text"] - if user_transcript.strip(): # Only add non-empty transcripts - user_message = {"role": "user", "content": user_transcript} - session.agent.messages.append(user_message) - logger.debug("User transcript added to history") - - except Exception as e: - logger.error("Model events error: %s", str(e)) - traceback.print_exc() - finally: - logger.debug("Model events processor stopped") + logger.debug("Event loop stopped - tools executed: %d", self.tool_count) + def schedule_tool_execution(self, tool_use: ToolUse) -> None: + """Create asyncio task for immediate tool execution (0ms scheduling).""" + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + + # Thread-safe counter increment + current_tool_number = self.tool_count + 1 + self.tool_count = current_tool_number + print(f"\nTool #{current_tool_number}: {tool_name}") + + logger.debug("Scheduling tool execution: %s (id: %s)", tool_name, tool_id) + + # Create task with UUID tracking + task_id = str(uuid.uuid4()) + task = asyncio.create_task(self._execute_tool_with_strands(tool_use)) + self.pending_tool_tasks[task_id] = task + + def cleanup_task(completed_task: asyncio.Task) -> None: + self.pending_tool_tasks.pop(task_id, None) + if completed_task.cancelled(): + logger.debug("Tool task cancelled: %s", task_id) + elif completed_task.exception(): + logger.error("Tool task error: %s - %s", task_id, completed_task.exception()) + else: + logger.debug("Tool task completed: %s", task_id) + + task.add_done_callback(cleanup_task) -async def _process_tool_execution(session: BidirectionalConnection) -> None: - """Execute tools concurrently with interruption support. + async def handle_interruption(self) -> None: + """Execute atomic interruption handling with race condition prevention. + + Always clears audio buffers for responsive interruption. + Protects tool execution by not cancelling tools when they are running. 
+ """ + async with self.interruption_lock: + if self.interrupted: + logger.debug("Interruption already in progress") + return - Background task that manages tool execution without blocking model event - processing or user interaction. Uses proper asyncio cancellation for - interruption handling rather than manual state checks. + logger.debug("Interruption detected") + self.interrupted = True - Args: - session: BidirectionalConnection containing tool queue. - """ - logger.debug("Tool execution processor started") - while session.active: - try: - tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) - tool_name = tool_use.get("name") - tool_id = tool_use.get("toolUseId") - - session.tool_count += 1 - print(f"\nTool #{session.tool_count}: {tool_name}") + # Check if tools are currently executing + active_tool_tasks = [task for task in self.pending_tool_tasks.values() if not task.done()] - logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id) - - task_id = str(uuid.uuid4()) - task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) - session.pending_tool_tasks[task_id] = task - - def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: + if active_tool_tasks: + logger.debug("Tools are protected - %d tools currently executing", len(active_tool_tasks)) + # Don't cancel tools, but still clear audio for responsive interruption + else: + logger.debug("No active tools - full interruption handling") + + # Always clear audio queues for responsive interruption (regardless of tool status) + cleared_count = 0 + while True: try: - # Remove from pending tasks - if task_id in session.pending_tool_tasks: - del session.pending_tool_tasks[task_id] + self.audio_output_queue.get_nowait() + cleared_count += 1 + except asyncio.QueueEmpty: + break - # Log completion status - if completed_task.cancelled(): - logger.debug("Tool task cancelled: %s", task_id) - elif completed_task.exception(): - logger.error("Tool task error: %s - %s", task_id, str(completed_task.exception())) + # Filter audio events from agent queue, preserve others + temp_events = [] + try: + while True: + event = self.agent._output_queue.get_nowait() + if event.get("audioOutput"): + cleared_count += 1 else: - logger.debug("Tool task completed: %s", task_id) - except Exception as e: - logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) - - task.add_done_callback(cleanup_task) - - except asyncio.TimeoutError: - if not session.active: - break - # Remove completed tasks from tracking - completed_tasks = [task_id for task_id, task in session.pending_tool_tasks.items() if task.done()] - for task_id in completed_tasks: - if task_id in session.pending_tool_tasks: - del session.pending_tool_tasks[task_id] - - if completed_tasks: - logger.debug("Periodic task cleanup: %d tasks", len(completed_tasks)) - - continue - except Exception as e: - logger.error("Tool execution error: %s", str(e)) - if not session.active: - break - - logger.debug("Tool execution processor stopped") - - - + temp_events.append(event) + except asyncio.QueueEmpty: + pass + # Restore non-audio events + for event in temp_events: + self.agent._output_queue.put_nowait(event) -async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: - """Execute tool using the complete Strands tool execution system. - - Uses proper Strands ToolExecutor system with validation, error handling, - and event streaming. 
- - Args: - session: BidirectionalConnection for context. - tool_use: Tool use event to execute. - """ - tool_name = tool_use.get("name") - tool_id = tool_use.get("toolUseId") - - logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) - - try: - # Create message structure for validation - tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - - # Use Strands validation system - tool_uses: list[ToolUse] = [] - tool_results: list[ToolResult] = [] - invalid_tool_use_ids: list[str] = [] - - validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - - # Filter valid tools - valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - - if not valid_tool_uses: - logger.warning("No valid tools after validation: %s", tool_name) - return - - # Create invocation state for tool execution - invocation_state = { - "agent": session.agent, - "model": session.agent.model, - "messages": session.agent.messages, - "system_prompt": session.agent.system_prompt, - } - - # Create cycle trace and span - cycle_trace = Trace("Bidirectional Tool Execution") - cycle_span = None - - tool_events = session.agent.tool_executor._execute( - session.agent, - valid_tool_uses, - tool_results, - cycle_trace, - cycle_span, - invocation_state - ) + self.interrupted = False + + if active_tool_tasks: + logger.debug("Interruption handled (tools protected) - audio cleared: %d", cleared_count) + else: + logger.debug("Interruption handled (full) - audio cleared: %d", cleared_count) + + async def _process_model_events(self) -> None: + """Process incoming provider event stream and dispatch to appropriate handlers.""" + logger.debug("Model events processor started") - # Process tool events and send results to provider - async for tool_event in tool_events: - if isinstance(tool_event, ToolResultEvent): - tool_result = tool_event.tool_result - tool_use_id = tool_result.get("toolUseId") + try: + async for provider_event in self.model_session.receive_events(): + if not self.active: + break + + if not isinstance(provider_event, dict): + continue + + strands_event = provider_event - # Send result through provider-specific session - await session.model_session.send_tool_result(tool_use_id, tool_result) - logger.debug("Tool result sent: %s", tool_use_id) + # Handle interruptions + if strands_event.get("interruptionDetected"): + logger.debug("Interruption detected from model") + await self.handle_interruption() + await self.agent._output_queue.put(strands_event) + continue - # Handle streaming events if needed later - elif isinstance(tool_event, ToolStreamEvent): - logger.debug("Tool stream event: %s", tool_event) - pass + # Schedule tool execution immediately + if strands_event.get("toolUse"): + tool_name = strands_event["toolUse"].get("name") + logger.debug("Tool request received: %s", tool_name) + self.schedule_tool_execution(strands_event["toolUse"]) + continue + + # Route audio to both queues + if strands_event.get("audioOutput"): + await self.audio_output_queue.put(strands_event) + await self.agent._output_queue.put(strands_event) + continue + + # Forward text output + if strands_event.get("textOutput"): + await self.agent._output_queue.put(strands_event) + + # Update conversation history (thread-safe) + if strands_event.get("messageStop"): + logger.debug("Adding message to conversation history") + async with self.conversation_lock: + self.agent.messages.append(strands_event["messageStop"]["message"]) + + # Handle user transcripts + if 
(strands_event.get("textOutput") and + strands_event["textOutput"].get("role") == "user"): + user_transcript = strands_event["textOutput"]["text"] + if user_transcript.strip(): + user_message = {"role": "user", "content": user_transcript} + async with self.conversation_lock: + self.agent.messages.append(user_message) + logger.debug("User transcript added to history") + + except Exception as e: + logger.error("Model events processor error: %s", e) + traceback.print_exc() + finally: + logger.debug("Model events processor stopped") + + async def _supervise_session(self) -> None: + """Monitor background task health using event-driven completion waiting.""" + logger.debug("Session supervisor started") - # Add tool result message to conversation history - if tool_results: - from ....hooks import MessageAddedEvent + try: + # Supervise tasks excluding self to avoid circular waiting + tasks_to_supervise = [task for task in self.background_tasks if task != asyncio.current_task()] - tool_result_message: Message = { - "role": "user", - "content": [{"toolResult": result} for result in tool_results], - } + while self.active and tasks_to_supervise: + # Wait for any task completion (deterministic vs polling) + done, pending = await asyncio.wait( + tasks_to_supervise, + return_when=asyncio.FIRST_COMPLETED, + timeout=1.0 # Periodic active flag check + ) + + # Check for task failures + for task in done: + if not task.cancelled(): + exception = task.exception() + if exception: + logger.error("Background task failed: %s", exception) + self.active = False + break + + # Remove completed tasks from supervision list + tasks_to_supervise = [task for task in tasks_to_supervise if not task.done()] - session.agent.messages.append(tool_result_message) - session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message)) - logger.debug("Tool result message added to history: %s", tool_name) - - logger.debug("Tool execution completed: %s", tool_name) + except Exception as e: + logger.error("Session supervisor error: %s", e) + finally: + logger.debug("Session supervisor stopped") + + async def _execute_tool_with_strands(self, tool_use: ToolUse) -> None: + """Execute tool using Strands validation and execution pipeline.""" + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") - except asyncio.CancelledError: - logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) - raise - except Exception as e: - logger.error("Tool execution error: %s - %s", tool_name, str(e)) + logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) - # Send error result - error_result: ToolResult = { - "toolUseId": tool_id, - "status": "error", - "content": [{"text": f"Error: {str(e)}"}] - } try: - await session.model_session.send_tool_result(tool_id, error_result) - logger.debug("Error result sent: %s", tool_id) - except Exception: - logger.error("Failed to send error result: %s", tool_id) - pass # Session might be closed + # Prepare for Strands validation system + tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + tool_uses: list[ToolUse] = [] + tool_results: list[ToolResult] = [] + invalid_tool_use_ids: list[str] = [] + + # Validate tools + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) + valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] + + if not valid_tool_uses: + logger.warning("No valid tools after validation: %s", tool_name) + return + + # 
Execute with agent context + invocation_state = { + "agent": self.agent, + "model": self.agent.model, + "messages": self.agent.messages, + "system_prompt": self.agent.system_prompt, + } + + cycle_trace = Trace("Bidirectional Tool Execution") + tool_events = self.agent.tool_executor._execute( + self.agent, valid_tool_uses, tool_results, cycle_trace, None, invocation_state + ) + + # Process tool event stream + async for tool_event in tool_events: + if isinstance(tool_event, ToolResultEvent): + tool_result = tool_event.tool_result + tool_use_id = tool_result.get("toolUseId") + await self.model_session.send_tool_result(tool_use_id, tool_result) + logger.debug("Tool result sent: %s", tool_use_id) + elif isinstance(tool_event, ToolStreamEvent): + logger.debug("Tool stream event: %s", tool_event) + + # Update conversation history (thread-safe) + if tool_results: + tool_result_message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + + async with self.conversation_lock: + self.agent.messages.append(tool_result_message) + self.agent.hooks.invoke_callbacks( + MessageAddedEvent(agent=self.agent, message=tool_result_message) + ) + logger.debug("Tool result message added to history: %s", tool_name) + + logger.debug("Tool execution completed: %s", tool_name) + + except asyncio.CancelledError: + logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) + raise + except Exception as e: + logger.error("Tool execution error: %s - %s", tool_name, e) + + # Send error result to provider + error_result: ToolResult = { + "toolUseId": tool_id, + "status": "error", + "content": [{"text": f"Error: {str(e)}"}] + } + + try: + await self.model_session.send_tool_result(tool_id, error_result) + logger.debug("Error result sent: %s", tool_id) + except Exception: + logger.error("Failed to send error result: %s", tool_id) + + +# Session lifecycle coordinator functions +async def start_bidirectional_connection(agent: "BidirectionalAgent") -> "BidirectionalEventLoop": + """Initialize and start bidirectional streaming session.""" + logger.debug("Creating bidirectional connection") + + model_session = await agent.model.create_bidirectional_connection( + system_prompt=agent.system_prompt, + tools=agent.tool_registry.get_all_tool_specs(), + messages=agent.messages + ) + + event_loop = BidirectionalEventLoop(model_session=model_session, agent=agent) + await event_loop.start() + + logger.debug("Bidirectional connection created and started") + return event_loop +async def stop_bidirectional_connection(event_loop: "BidirectionalEventLoop") -> None: + """Terminate bidirectional streaming session and cleanup resources.""" + await event_loop.stop() \ No newline at end of file From 1944d0d3238200682ff6db51a1d36eccbacc55db Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 28 Oct 2025 09:07:34 -0400 Subject: [PATCH 2/5] (feat): Improve bidi event loop --- .../event_loop/bidirectional_event_loop.py | 220 +++++++----------- 1 file changed, 78 insertions(+), 142 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index cd29c9e0e..6056b340e 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -2,16 +2,15 @@ import asyncio import logging -import traceback import uuid -from typing 
import TYPE_CHECKING, Optional, Dict, List +from typing import TYPE_CHECKING, Dict, List -from ....tools._validator import validate_and_prepare_tools +from ....hooks import MessageAddedEvent from ....telemetry.metrics import Trace +from ....tools._validator import validate_and_prepare_tools from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ....hooks import MessageAddedEvent from ..models.bidirectional_model import BidirectionalModelSession if TYPE_CHECKING: @@ -20,93 +19,90 @@ logger = logging.getLogger(__name__) -class BidirectionalEventLoop: - """Event loop coordinator for bidirectional streaming sessions. - - Manages concurrent background tasks for model event processing and session supervision. - Tool execution uses immediate asyncio.Task creation (0ms scheduling) rather than polling. - Provides atomic interruption handling and race condition prevention. +class BidirectionalAgentLoop: + """Agent loop coordinator for bidirectional streaming sessions. + + Manages model event processing with tool execution. + Tool execution uses asyncio.create_task() for concurrent execution. """ def __init__(self, model_session: BidirectionalModelSession, agent: "BidirectionalAgent"): - """Initialize event loop with model session and agent dependencies.""" + """Initialize agent loop with model session and agent dependencies.""" self.model_session = model_session self.agent = agent self.active = True - + # Task tracking self.background_tasks: List[asyncio.Task] = [] self.pending_tool_tasks: Dict[str, asyncio.Task] = {} - + # Synchronization primitives self.interrupted = False self.interruption_lock = asyncio.Lock() - self.conversation_lock = asyncio.Lock() # Race condition prevention - + self.conversation_lock = asyncio.Lock() + # Audio and metrics self.audio_output_queue = asyncio.Queue() self.tool_count = 0 - - logger.debug("BidirectionalEventLoop initialized") + + logger.debug("BidirectionalAgentLoop initialized") async def start(self) -> None: - """Start background tasks for model event processing and session supervision.""" - logger.debug("Starting bidirectional event loop") - - self.background_tasks = [ - asyncio.create_task(self._process_model_events()), - asyncio.create_task(self._supervise_session()), - ] - - logger.debug("Event loop started with %d background tasks", len(self.background_tasks)) + """Start background task for model event processing.""" + logger.debug("Starting bidirectional agent loop") + + self.background_tasks = [asyncio.create_task(self._process_model_events())] + + logger.debug("Agent loop started") async def stop(self) -> None: """Gracefully shutdown and cleanup all resources.""" if not self.active: return - - logger.debug("Stopping bidirectional event loop") + + logger.debug("Stopping bidirectional agent loop") self.active = False - - # Cancel all tasks + + # Cancel pending tool tasks for task in self.pending_tool_tasks.values(): if not task.done(): task.cancel() - + + # Cancel background task for task in self.background_tasks: if not task.done(): task.cancel() - + # Wait for cancellations all_tasks = list(self.pending_tool_tasks.values()) + self.background_tasks if all_tasks: await asyncio.gather(*all_tasks, return_exceptions=True) - + # Close model session try: await self.model_session.close() except Exception as e: logger.warning("Error closing model session: %s", e) - - logger.debug("Event loop stopped - tools executed: %d", self.tool_count) + + logger.debug("Agent 
loop stopped - tools executed: %d", self.tool_count) def schedule_tool_execution(self, tool_use: ToolUse) -> None: - """Create asyncio task for immediate tool execution (0ms scheduling).""" + """Create asyncio task for tool execution and add to pending tasks tracking.""" tool_name = tool_use.get("name") tool_id = tool_use.get("toolUseId") - + # Thread-safe counter increment current_tool_number = self.tool_count + 1 self.tool_count = current_tool_number print(f"\nTool #{current_tool_number}: {tool_name}") - + logger.debug("Scheduling tool execution: %s (id: %s)", tool_name, tool_id) - + # Create task with UUID tracking task_id = str(uuid.uuid4()) - task = asyncio.create_task(self._execute_tool_with_strands(tool_use)) + task = asyncio.create_task(self._execute_tool(tool_use)) self.pending_tool_tasks[task_id] = task - + def cleanup_task(completed_task: asyncio.Task) -> None: self.pending_tool_tasks.pop(task_id, None) if completed_task.cancelled(): @@ -115,14 +111,14 @@ def cleanup_task(completed_task: asyncio.Task) -> None: logger.error("Tool task error: %s - %s", task_id, completed_task.exception()) else: logger.debug("Tool task completed: %s", task_id) - + task.add_done_callback(cleanup_task) async def handle_interruption(self) -> None: """Execute atomic interruption handling with race condition prevention. - - Always clears audio buffers for responsive interruption. - Protects tool execution by not cancelling tools when they are running. + + Clears audio buffers and checks tool execution status. + Does not cancel running tools - only clears audio output queues. """ async with self.interruption_lock: if self.interrupted: @@ -134,14 +130,13 @@ async def handle_interruption(self) -> None: # Check if tools are currently executing active_tool_tasks = [task for task in self.pending_tool_tasks.values() if not task.done()] - + if active_tool_tasks: logger.debug("Tools are protected - %d tools currently executing", len(active_tool_tasks)) - # Don't cancel tools, but still clear audio for responsive interruption else: logger.debug("No active tools - full interruption handling") - # Always clear audio queues for responsive interruption (regardless of tool status) + # Clear audio output queue cleared_count = 0 while True: try: @@ -167,127 +162,93 @@ async def handle_interruption(self) -> None: self.agent._output_queue.put_nowait(event) self.interrupted = False - + if active_tool_tasks: logger.debug("Interruption handled (tools protected) - audio cleared: %d", cleared_count) else: logger.debug("Interruption handled (full) - audio cleared: %d", cleared_count) async def _process_model_events(self) -> None: - """Process incoming provider event stream and dispatch to appropriate handlers.""" - logger.debug("Model events processor started") - + """Process incoming provider event stream and dispatch events to handlers.""" + logger.debug("Agent loop processor started") + try: async for provider_event in self.model_session.receive_events(): if not self.active: break - + if not isinstance(provider_event, dict): continue - + strands_event = provider_event - + # Handle interruptions if strands_event.get("interruptionDetected"): logger.debug("Interruption detected from model") await self.handle_interruption() await self.agent._output_queue.put(strands_event) continue - - # Schedule tool execution immediately + + # Schedule tool execution if strands_event.get("toolUse"): tool_name = strands_event["toolUse"].get("name") logger.debug("Tool request received: %s", tool_name) 
self.schedule_tool_execution(strands_event["toolUse"]) continue - + # Route audio to both queues if strands_event.get("audioOutput"): await self.audio_output_queue.put(strands_event) await self.agent._output_queue.put(strands_event) continue - + # Forward text output if strands_event.get("textOutput"): await self.agent._output_queue.put(strands_event) - - # Update conversation history (thread-safe) + + # Update conversation history if strands_event.get("messageStop"): logger.debug("Adding message to conversation history") async with self.conversation_lock: self.agent.messages.append(strands_event["messageStop"]["message"]) - + # Handle user transcripts - if (strands_event.get("textOutput") and - strands_event["textOutput"].get("role") == "user"): + if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": user_transcript = strands_event["textOutput"]["text"] if user_transcript.strip(): user_message = {"role": "user", "content": user_transcript} async with self.conversation_lock: self.agent.messages.append(user_message) logger.debug("User transcript added to history") - - except Exception as e: - logger.error("Model events processor error: %s", e) - traceback.print_exc() - finally: - logger.debug("Model events processor stopped") - async def _supervise_session(self) -> None: - """Monitor background task health using event-driven completion waiting.""" - logger.debug("Session supervisor started") - - try: - # Supervise tasks excluding self to avoid circular waiting - tasks_to_supervise = [task for task in self.background_tasks if task != asyncio.current_task()] - - while self.active and tasks_to_supervise: - # Wait for any task completion (deterministic vs polling) - done, pending = await asyncio.wait( - tasks_to_supervise, - return_when=asyncio.FIRST_COMPLETED, - timeout=1.0 # Periodic active flag check - ) - - # Check for task failures - for task in done: - if not task.cancelled(): - exception = task.exception() - if exception: - logger.error("Background task failed: %s", exception) - self.active = False - break - - # Remove completed tasks from supervision list - tasks_to_supervise = [task for task in tasks_to_supervise if not task.done()] - except Exception as e: - logger.error("Session supervisor error: %s", e) + logger.error("Agent loop processor failed: %s", e) + self.active = False finally: - logger.debug("Session supervisor stopped") + logger.debug("Agent loop processor stopped") - async def _execute_tool_with_strands(self, tool_use: ToolUse) -> None: + async def _execute_tool(self, tool_use: ToolUse) -> None: """Execute tool using Strands validation and execution pipeline.""" tool_name = tool_use.get("name") tool_id = tool_use.get("toolUseId") - + logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) - + try: # Prepare for Strands validation system tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] - + # Validate tools validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - + if not valid_tool_uses: logger.warning("No valid tools after validation: %s", tool_name) return - + # Execute with agent context invocation_state = { "agent": self.agent, @@ -295,12 +256,12 @@ async def _execute_tool_with_strands(self, tool_use: ToolUse) -> None: "messages": self.agent.messages, "system_prompt": 
self.agent.system_prompt, } - + cycle_trace = Trace("Bidirectional Tool Execution") tool_events = self.agent.tool_executor._execute( self.agent, valid_tool_uses, tool_results, cycle_trace, None, invocation_state ) - + # Process tool event stream async for tool_event in tool_events: if isinstance(tool_event, ToolResultEvent): @@ -310,61 +271,36 @@ async def _execute_tool_with_strands(self, tool_use: ToolUse) -> None: logger.debug("Tool result sent: %s", tool_use_id) elif isinstance(tool_event, ToolStreamEvent): logger.debug("Tool stream event: %s", tool_event) - - # Update conversation history (thread-safe) + + # Update conversation history using conversation lock if tool_results: tool_result_message: Message = { "role": "user", "content": [{"toolResult": result} for result in tool_results], } - + async with self.conversation_lock: self.agent.messages.append(tool_result_message) - self.agent.hooks.invoke_callbacks( - MessageAddedEvent(agent=self.agent, message=tool_result_message) - ) + self.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=self.agent, message=tool_result_message)) logger.debug("Tool result message added to history: %s", tool_name) - + logger.debug("Tool execution completed: %s", tool_name) - + except asyncio.CancelledError: logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) raise except Exception as e: logger.error("Tool execution error: %s - %s", tool_name, e) - + # Send error result to provider error_result: ToolResult = { "toolUseId": tool_id, "status": "error", - "content": [{"text": f"Error: {str(e)}"}] + "content": [{"text": f"Error: {str(e)}"}], } - + try: await self.model_session.send_tool_result(tool_id, error_result) logger.debug("Error result sent: %s", tool_id) except Exception: logger.error("Failed to send error result: %s", tool_id) - - -# Session lifecycle coordinator functions -async def start_bidirectional_connection(agent: "BidirectionalAgent") -> "BidirectionalEventLoop": - """Initialize and start bidirectional streaming session.""" - logger.debug("Creating bidirectional connection") - - model_session = await agent.model.create_bidirectional_connection( - system_prompt=agent.system_prompt, - tools=agent.tool_registry.get_all_tool_specs(), - messages=agent.messages - ) - - event_loop = BidirectionalEventLoop(model_session=model_session, agent=agent) - await event_loop.start() - - logger.debug("Bidirectional connection created and started") - return event_loop - - -async def stop_bidirectional_connection(event_loop: "BidirectionalEventLoop") -> None: - """Terminate bidirectional streaming session and cleanup resources.""" - await event_loop.stop() \ No newline at end of file From 288d9fedf7bc15f97f46da4b0d15c936dee49368 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 28 Oct 2025 09:22:45 -0400 Subject: [PATCH 3/5] (feat): Improve bidi event loop --- .../event_loop/bidirectional_event_loop.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 6056b340e..d9e810c27 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -41,8 +41,8 @@ def __init__(self, model_session: BidirectionalModelSession, agent: "Bidirection self.interruption_lock = asyncio.Lock() self.conversation_lock = 
asyncio.Lock() - # Audio and metrics - self.audio_output_queue = asyncio.Queue() + # Output queue and metrics + self.event_output_queue = asyncio.Queue() self.tool_count = 0 logger.debug("BidirectionalAgentLoop initialized") @@ -136,11 +136,11 @@ async def handle_interruption(self) -> None: else: logger.debug("No active tools - full interruption handling") - # Clear audio output queue + # Clear output queue cleared_count = 0 while True: try: - self.audio_output_queue.get_nowait() + self.event_output_queue.get_nowait() cleared_count += 1 except asyncio.QueueEmpty: break @@ -198,7 +198,7 @@ async def _process_model_events(self) -> None: # Route audio to both queues if strands_event.get("audioOutput"): - await self.audio_output_queue.put(strands_event) + await self.event_output_queue.put(strands_event) await self.agent._output_queue.put(strands_event) continue From 11d7ad74272e01a51ea0f9fcdd5761d5058fc5a4 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 28 Oct 2025 09:23:09 -0400 Subject: [PATCH 4/5] (feat): Improve bidi event loop --- .../event_loop/bidirectional_event_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index d9e810c27..2982ada42 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -91,7 +91,6 @@ def schedule_tool_execution(self, tool_use: ToolUse) -> None: tool_name = tool_use.get("name") tool_id = tool_use.get("toolUseId") - # Thread-safe counter increment current_tool_number = self.tool_count + 1 self.tool_count = current_tool_number print(f"\nTool #{current_tool_number}: {tool_name}") From 70edc6ab18837b1dfbc60c57dee7c2b77f07aef2 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 28 Oct 2025 09:33:58 -0400 Subject: [PATCH 5/5] (feat): Improve bidi event loop --- .../bidirectional_streaming/event_loop/__init__.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py index af8c4e1e1..8e8f121dd 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py @@ -1,15 +1,9 @@ """Event loop management for bidirectional streaming.""" from .bidirectional_event_loop import ( - BidirectionalConnection, - bidirectional_event_loop_cycle, - start_bidirectional_connection, - stop_bidirectional_connection, + BidirectionalAgentLoop, ) __all__ = [ - "BidirectionalConnection", - "start_bidirectional_connection", - "stop_bidirectional_connection", - "bidirectional_event_loop_cycle", + "BidirectionalAgentLoop", ]
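
Usage sketch (editor's note, not part of the patch series): after patch 5 the event_loop package exports only BidirectionalAgentLoop, and the module-level start/stop helpers removed in patch 2 are gone, so a caller is expected to wire the loop up itself. The sketch below is a minimal, hypothetical driver under those assumptions — it relies only on attributes the patches already use on the agent (model, system_prompt, messages, tool_registry, _output_queue) and on the model's create_bidirectional_connection() factory from patch 1; run_session itself is not an API defined by these patches.

import asyncio

from strands.experimental.bidirectional_streaming.event_loop import BidirectionalAgentLoop


async def run_session(agent) -> None:
    # Create the provider-specific session the same way patch 1's removed helper did.
    model_session = await agent.model.create_bidirectional_connection(
        system_prompt=agent.system_prompt,
        tools=agent.tool_registry.get_all_tool_specs(),
        messages=agent.messages,
    )

    loop = BidirectionalAgentLoop(model_session=model_session, agent=agent)
    await loop.start()
    try:
        # The background processor forwards converted events to the agent's
        # output queue; drain it until the loop deactivates.
        while loop.active:
            try:
                event = await asyncio.wait_for(agent._output_queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            if event.get("textOutput"):
                print(event["textOutput"]["text"])
    finally:
        await loop.stop()

# Entry point, e.g.: asyncio.run(run_session(my_agent))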