diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 0f842ee9f..dcbdb4e07 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -2,9 +2,8 @@ # Main components - Primary user interface from .agent.agent import BidirectionalAgent - -# Advanced interfaces (for custom implementations) -from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .models.base_model import BidirectionalModel +from .models.base_session import BidirectionalModelSession # Model providers - What users need to create models from .models.novasonic import NovaSonicBidirectionalModel diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 26b964c53..565c985a1 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -30,10 +30,9 @@ from ....types.tools import ToolResult, ToolUse from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection -from ..models.bidirectional_model import BidirectionalModel +from ..models.base_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent - logger = logging.getLogger(__name__) _DEFAULT_AGENT_NAME = "Strands Agents" @@ -81,7 +80,7 @@ def caller( Args: user_message_override: Optional custom message to record instead of default - record_direct_tool_call: Whether to record direct tool calls in message history. + record_direct_tool_call: Whether to record direct tool calls in message history. For bidirectional agents, this is always True to maintain conversation history. **kwargs: Keyword arguments to pass to the tool. @@ -186,12 +185,12 @@ def __init__( self.model = model self.system_prompt = system_prompt self.messages = messages or [] - + # Agent identification self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME self.description = description - + # Tool execution configuration self.record_direct_tool_call = record_direct_tool_call self.load_tools_from_directory = load_tools_from_directory @@ -207,25 +206,25 @@ def __init__( # Initialize tool registry self.tool_registry = ToolRegistry() - + if tools is not None: self.tool_registry.process_tools(tools) - + self.tool_registry.initialize_tools(self.load_tools_from_directory) - + # Initialize tool watcher if directory loading is enabled if self.load_tools_from_directory: self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) # Initialize tool executor self.tool_executor = tool_executor or ConcurrentToolExecutor() - + # Initialize hooks system self.hooks = HookRegistry() if hooks: for hook in hooks: self.hooks.add_hook(hook) - + # Initialize other components self.event_loop_metrics = EventLoopMetrics() self.tool_caller = BidirectionalAgent.ToolCaller(self) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 69f5d759d..df0f393f7 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -16,13 +16,13 @@ import traceback import uuid -from ....tools._validator import validate_and_prepare_tools from ....telemetry.metrics import Trace +from ....tools._validator import validate_and_prepare_tools from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ..models.bidirectional_model import BidirectionalModelSession - +from .. import BidirectionalAgent +from ..models.base_session import BidirectionalModelSession logger = logging.getLogger(__name__) @@ -60,7 +60,7 @@ def __init__(self, model_session: BidirectionalModelSession, agent: "Bidirection # Interruption handling (model-agnostic) self.interrupted = False self.interruption_lock = asyncio.Lock() - + # Tool execution tracking self.tool_count = 0 @@ -265,7 +265,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Basic validation - skip invalid events if not isinstance(provider_event, dict): continue - + strands_event = provider_event # Handle interruption detection (provider converts raw patterns to interruptionDetected) @@ -291,7 +291,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: if strands_event.get("messageStop"): logger.debug("Message added to history") session.agent.messages.append(strands_event["messageStop"]["message"]) - + # Handle user audio transcripts - add to message history if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": user_transcript = strands_event["textOutput"]["text"] @@ -311,7 +311,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. Background task that manages tool execution without blocking model event - processing or user interaction. Uses proper asyncio cancellation for + processing or user interaction. Uses proper asyncio cancellation for interruption handling rather than manual state checks. Args: @@ -323,10 +323,10 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) tool_name = tool_use.get("name") tool_id = tool_use.get("toolUseId") - + session.tool_count += 1 print(f"\nTool #{session.tool_count}: {tool_name}") - + logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id) task_id = str(uuid.uuid4()) @@ -372,42 +372,39 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: logger.debug("Tool execution processor stopped") - - - async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: """Execute tool using the complete Strands tool execution system. - + Uses proper Strands ToolExecutor system with validation, error handling, and event streaming. - + Args: session: BidirectionalConnection for context. tool_use: Tool use event to execute. """ tool_name = tool_use.get("name") tool_id = tool_use.get("toolUseId") - + logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) - + try: - # Create message structure for validation + # Create message structure for validation tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - + # Use Strands validation system tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] - + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - + # Filter valid tools valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - + if not valid_tool_uses: logger.warning("No valid tools after validation: %s", tool_name) return - + # Create invocation state for tool execution invocation_state = { "agent": session.agent, @@ -415,67 +412,56 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: "messages": session.agent.messages, "system_prompt": session.agent.system_prompt, } - + # Create cycle trace and span cycle_trace = Trace("Bidirectional Tool Execution") cycle_span = None - + tool_events = session.agent.tool_executor._execute( - session.agent, - valid_tool_uses, - tool_results, - cycle_trace, - cycle_span, - invocation_state + session.agent, valid_tool_uses, tool_results, cycle_trace, cycle_span, invocation_state ) - + # Process tool events and send results to provider async for tool_event in tool_events: if isinstance(tool_event, ToolResultEvent): tool_result = tool_event.tool_result tool_use_id = tool_result.get("toolUseId") - + # Send result through provider-specific session await session.model_session.send_tool_result(tool_use_id, tool_result) logger.debug("Tool result sent: %s", tool_use_id) - + # Handle streaming events if needed later elif isinstance(tool_event, ToolStreamEvent): logger.debug("Tool stream event: %s", tool_event) pass - + # Add tool result message to conversation history if tool_results: from ....hooks import MessageAddedEvent - + tool_result_message: Message = { "role": "user", "content": [{"toolResult": result} for result in tool_results], } - + session.agent.messages.append(tool_result_message) session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message)) logger.debug("Tool result message added to history: %s", tool_name) - + logger.debug("Tool execution completed: %s", tool_name) - + except asyncio.CancelledError: logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) raise except Exception as e: logger.error("Tool execution error: %s - %s", tool_name, str(e)) - - # Send error result - error_result: ToolResult = { - "toolUseId": tool_id, - "status": "error", - "content": [{"text": f"Error: {str(e)}"}] - } + + # Send error result + error_result: ToolResult = {"toolUseId": tool_id, "status": "error", "content": [{"text": f"Error: {str(e)}"}]} try: await session.model_session.send_tool_result(tool_id, error_result) logger.debug("Error result sent: %s", tool_id) except Exception: logger.error("Failed to send error result: %s", tool_id) pass # Session might be closed - - diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 3a785e98a..a4d75de9b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,6 +1,7 @@ """Bidirectional model interfaces and implementations.""" -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .base_model import BidirectionalModel +from .base_session import BidirectionalModelSession from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession __all__ = [ diff --git a/src/strands/experimental/bidirectional_streaming/models/base_model.py b/src/strands/experimental/bidirectional_streaming/models/base_model.py new file mode 100644 index 000000000..5da291be9 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/base_model.py @@ -0,0 +1,45 @@ +"""Unified bidirectional streaming interface. + +Single layer combining model and session abstractions for simpler implementation. +""" + +from typing import AsyncIterable, Protocol, Union + +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec +from ..types.bidirectional_streaming import ( + AudioInputEvent, + BidirectionalStreamEvent, + ImageInputEvent, + TextInputEvent, +) + + +class BaseModel(Protocol): + """Unified interface for bidirectional streaming models. + + Combines model configuration and session communication in a single abstraction. + Providers implement this directly without separate model/session classes. + """ + + async def connect( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> None: + """Establish bidirectional connection with the model.""" + ... + + async def close(self) -> None: + """Close connection and cleanup resources.""" + ... + + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive events from the model in standardized format.""" + ... + + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Send structured content to the model.""" + ... \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py deleted file mode 100644 index d5c3c9b65..000000000 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Bidirectional model interface for real-time streaming conversations. - -Defines the interface for models that support bidirectional streaming capabilities. -Provides abstractions for different model providers with connection-based communication -patterns that support real-time audio and text interaction. - -Features: -- connection-based persistent connections -- Real-time bidirectional communication -- Provider-agnostic event normalization -- Tool execution integration -""" - -import abc -import logging -from typing import AsyncIterable - -from ....types.content import Messages -from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent - -logger = logging.getLogger(__name__) - - -class BidirectionalModelSession(abc.ABC): - """Abstract interface for model-specific bidirectional communication connections. - - Defines the contract for managing persistent streaming connections with individual - model providers, handling audio/text input, receiving events, and managing - tool execution results. - """ - - @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive events from the model in standardized format. - - Converts provider-specific events to a common format that can be - processed uniformly by the event loop. - """ - raise NotImplementedError - - @abc.abstractmethod - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to the model during an active connection. - - Handles audio encoding and provider-specific formatting while presenting - a simple AudioInputEvent interface. - """ - raise NotImplementedError - - @abc.abstractmethod - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content to the model during ongoing generation. - - Allows natural interruption and follow-up questions without requiring - connection restart. - """ - raise NotImplementedError - - @abc.abstractmethod - async def send_interrupt(self) -> None: - """Send interruption signal to stop generation immediately. - - Enables responsive conversational experiences where users can - naturally interrupt during model responses. - """ - raise NotImplementedError - - @abc.abstractmethod - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool execution result to the model. - - Formats and sends tool results according to the provider's specific protocol. - Handles both successful results and error cases through the result dictionary. - """ - raise NotImplementedError - - @abc.abstractmethod - async def close(self) -> None: - """Close the connection and cleanup resources.""" - raise NotImplementedError - - -class BidirectionalModel(abc.ABC): - """Interface for models that support bidirectional streaming. - - Defines the contract for creating persistent streaming connections that support - real-time audio and text communication with AI models. - """ - - @abc.abstractmethod - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create a bidirectional connection with the model. - - Establishes a persistent connection for real-time communication while - abstracting provider-specific initialization requirements. - """ - raise NotImplementedError diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 7f35a3c1c..c5e34b8bf 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -35,10 +35,10 @@ BidirectionalConnectionStartEvent, InterruptionDetectedEvent, TextOutputEvent, - UsageMetricsEvent + UsageMetricsEvent, ) - -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .base_model import BidirectionalModel +from .base_session import BidirectionalModelSession logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index c0f6eb209..2efe42ee6 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -70,6 +70,15 @@ class AudioInputEvent(TypedDict): sampleRate: Literal[16000, 24000, 48000] channels: Literal[1, 2] +class TextInputEvent(TypedDict): + """Text input event for sending text to the model. + + Attributes: + text: The text content to send to the model. + """ + + text: str + role: Role class TextOutputEvent(TypedDict): """Text output event from the model during bidirectional streaming. @@ -83,6 +92,26 @@ class TextOutputEvent(TypedDict): role: Role +class ImageInputEvent(TypedDict): + """Image input event for sending images/video frames to the model. + + Supports multiple input methods following OpenAI realtime API patterns: + - Base64 data URLs (data:image/png;base64,...) + - Hosted URLs (https://...) + - OpenAI file IDs (file-...) + - Raw bytes with MIME type + + Attributes: + image_url: Data URL, hosted URL, or OpenAI file ID. + imageData: Raw image bytes (alternative to image_url). + mimeType: MIME type when using imageData. + """ + + image_url: Optional[str] # Primary: data URL, hosted URL, or file ID + imageData: Optional[bytes] # Alternative: raw bytes + mimeType: Optional[str] # Required when using imageData + + class InterruptionDetectedEvent(TypedDict): """Interruption detection event. @@ -120,6 +149,7 @@ class BidirectionalConnectionEndEvent(TypedDict): connectionId: Optional[str] metadata: Optional[Dict[str, Any]] + class UsageMetricsEvent(TypedDict): """Token usage and performance tracking. @@ -162,4 +192,3 @@ class BidirectionalStreamEvent(StreamEvent, total=False): BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] usageMetrics: Optional[UsageMetricsEvent] -