From 9368c82c76f6ca858d355c68624e078c8b95cf4e Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 14 Oct 2025 08:30:23 -0400 Subject: [PATCH 1/2] feat(tool_executor): Plug tool executor into bidirectional streaming implementation --- .../bidirectional_streaming/__init__.py | 41 +- .../bidirectional_streaming/agent/agent.py | 297 +++++++++- .../event_loop/bidirectional_event_loop.py | 179 +++--- .../models/__init__.py | 9 +- .../models/novasonic.py | 23 +- .../bidirectional_streaming/models/openai.py | 522 ++++++++++++++++++ .../tests/test_bidi_openai.py | 317 +++++++++++ .../tests/test_bidirectional_streaming.py | 27 + 8 files changed, 1317 insertions(+), 98 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/models/openai.py create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 52822711a..844a8a1f8 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,2 +1,41 @@ -"""Bidirectional streaming package for real-time audio/text conversations.""" +"""Bidirectional streaming package.""" +# Main components - Primary user interface +from .agent.agent import BidirectionalAgent + +# Advanced interfaces (for custom implementations) +from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession + +# Model providers - What users need to create models +from .models.novasonic import NovaSonicBidirectionalModel +from .models.openai import OpenAIRealtimeBidirectionalModel + +# Event types - For type hints and event handling +from .types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalStreamEvent, + InterruptionDetectedEvent, + TextOutputEvent, + UsageMetricsEvent, + VoiceActivityEvent, +) + +__all__ = [ + # Main interface + "BidirectionalAgent", + # Model providers + "NovaSonicBidirectionalModel", + "OpenAIRealtimeBidirectionalModel", + # Event types + "AudioInputEvent", + "AudioOutputEvent", + "TextOutputEvent", + "InterruptionDetectedEvent", + "BidirectionalStreamEvent", + "VoiceActivityEvent", + "UsageMetricsEvent", + # Model interface + "BidirectionalModel", + "BidirectionalModelSession", +] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 68d371a51..26b964c53 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -13,12 +13,22 @@ """ import asyncio +import json import logging -from typing import AsyncIterable +import random +from concurrent.futures import ThreadPoolExecutor +from typing import Any, AsyncIterable, Callable, Mapping, Optional +from .... 
import _identifier +from ....hooks import HookProvider, HookRegistry +from ....telemetry.metrics import EventLoopMetrics from ....tools.executors import ConcurrentToolExecutor +from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry -from ....types.content import Messages +from ....tools.watcher import ToolWatcher +from ....types.content import Message, Messages +from ....types.tools import ToolResult, ToolUse +from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent @@ -26,6 +36,9 @@ logger = logging.getLogger(__name__) +_DEFAULT_AGENT_NAME = "Strands Agents" +_DEFAULT_AGENT_ID = "default" + class BidirectionalAgent: """Agent for bidirectional streaming conversations. @@ -34,12 +47,125 @@ class BidirectionalAgent: sessions. Supports concurrent tool execution and interruption handling. """ + class ToolCaller: + """Call tool as a function for bidirectional agent.""" + + def __init__(self, agent: "BidirectionalAgent") -> None: + """Initialize tool caller with agent reference.""" + # WARNING: Do not add any other member variables or methods as this could result in a name conflict with + # agent tools and thus break their execution. + self._agent = agent + + def __getattr__(self, name: str) -> Callable[..., Any]: + """Call tool as a function. + + This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). + + Args: + name: The name of the attribute (tool) being accessed. + + Returns: + A function that when called will execute the named tool. + + Raises: + AttributeError: If no tool with the given name exists or if multiple tools match the given name. + """ + + def caller( + user_message_override: Optional[str] = None, + record_direct_tool_call: Optional[bool] = None, + **kwargs: Any, + ) -> Any: + """Call a tool directly by name. + + Args: + user_message_override: Optional custom message to record instead of default + record_direct_tool_call: Whether to record direct tool calls in message history. + For bidirectional agents, this is always True to maintain conversation history. + **kwargs: Keyword arguments to pass to the tool. + + Returns: + The result returned by the tool. + + Raises: + AttributeError: If the tool doesn't exist. 
+ """ + normalized_name = self._find_normalized_tool_name(name) + + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs + + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + _ = event + + return tool_results[0] + + def tcall() -> ToolResult: + return asyncio.run(acall()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(tcall) + tool_result = future.result() + + # Always record direct tool calls for bidirectional agents to maintain conversation history + # Use agent's record_direct_tool_call setting if not overridden + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call + + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + self._agent._record_tool_execution(tool_use, tool_result, user_message_override) + + return tool_result + + return caller + + def _find_normalized_tool_name(self, name: str) -> str: + """Lookup the tool represented by name, replacing characters with underscores as necessary.""" + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # If the desired name contains underscores, it might be a placeholder for characters that can't be + # represented as python identifiers but are valid as tool names, such as dashes. In that case, find + # all tools that can be represented with the normalized name + if "_" in name: + filtered_tools = [ + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name + ] + + # The registry itself defends against similar names, so we can just take the first match + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + def __init__( self, model: BidirectionalModel, tools: list | None = None, system_prompt: str | None = None, messages: Messages | None = None, + record_direct_tool_call: bool = True, + load_tools_from_directory: bool = False, + agent_id: Optional[str] = None, + name: Optional[str] = None, + tool_executor: Optional[ToolExecutor] = None, + hooks: Optional[list[HookProvider]] = None, + trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + description: Optional[str] = None, ): """Initialize bidirectional agent with required model and optional configuration. @@ -48,24 +174,177 @@ def __init__( tools: Optional list of tools available to the model. system_prompt: Optional system prompt for conversations. messages: Optional conversation history to initialize with. + record_direct_tool_call: Whether to record direct tool calls in message history. + load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. + agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios. + name: Name of the Agent. + tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + hooks: Hooks to be added to the agent hook registry. + trace_attributes: Custom trace attributes to apply to the agent's trace span. + description: Description of what the Agent does. 
""" self.model = model self.system_prompt = system_prompt self.messages = messages or [] - - # Initialize tool registry using existing Strands infrastructure + + # Agent identification + self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) + self.name = name or _DEFAULT_AGENT_NAME + self.description = description + + # Tool execution configuration + self.record_direct_tool_call = record_direct_tool_call + self.load_tools_from_directory = load_tools_from_directory + + # Process trace attributes to ensure they're of compatible types + self.trace_attributes: dict[str, AttributeValue] = {} + if trace_attributes: + for k, v in trace_attributes.items(): + if isinstance(v, (str, int, float, bool)) or ( + isinstance(v, list) and all(isinstance(x, (str, int, float, bool)) for x in v) + ): + self.trace_attributes[k] = v + + # Initialize tool registry self.tool_registry = ToolRegistry() - if tools: + + if tools is not None: self.tool_registry.process_tools(tools) - self.tool_registry.initialize_tools() - - # Initialize tool executor for concurrent execution - self.tool_executor = ConcurrentToolExecutor() + + self.tool_registry.initialize_tools(self.load_tools_from_directory) + + # Initialize tool watcher if directory loading is enabled + if self.load_tools_from_directory: + self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) + + # Initialize tool executor + self.tool_executor = tool_executor or ConcurrentToolExecutor() + + # Initialize hooks system + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + + # Initialize other components + self.event_loop_metrics = EventLoopMetrics() + self.tool_caller = BidirectionalAgent.ToolCaller(self) # Session management self._session = None self._output_queue = asyncio.Queue() + @property + def tool(self) -> ToolCaller: + """Call tool as a function. + + Returns: + Tool caller through which user can invoke tool as a function. + + Example: + ``` + agent = BidirectionalAgent(model=model, tools=[calculator]) + agent.tool.calculator(expression="2+2") + ``` + """ + return self.tool_caller + + @property + def tool_names(self) -> list[str]: + """Get a list of all registered tool names. + + Returns: + Names of all tools available to this agent. + """ + all_tools = self.tool_registry.get_all_tools_config() + return list(all_tools.keys()) + + def _record_tool_execution( + self, + tool: ToolUse, + tool_result: ToolResult, + user_message_override: Optional[str], + ) -> None: + """Record a tool execution in the message history. + + Creates a sequence of messages that represent the tool execution: + + 1. A user message describing the tool call + 2. An assistant message with the tool use + 3. A user message with the tool result + 4. An assistant message acknowledging the tool call + + Args: + tool: The tool call information. + tool_result: The result returned by the tool. + user_message_override: Optional custom message to include. 
+ """ + # Filter tool input parameters to only include those defined in tool spec + filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) + + # Create user message describing the tool call + input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") + + user_msg_content = [ + {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} + ] + + # Add override message if provided + if user_message_override: + user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) + + # Create filtered tool use for message history + filtered_tool: ToolUse = { + "toolUseId": tool["toolUseId"], + "name": tool["name"], + "input": filtered_input, + } + + # Create the message sequence + user_msg: Message = { + "role": "user", + "content": user_msg_content, + } + tool_use_msg: Message = { + "role": "assistant", + "content": [{"toolUse": filtered_tool}], + } + tool_result_msg: Message = { + "role": "user", + "content": [{"toolResult": tool_result}], + } + assistant_msg: Message = { + "role": "assistant", + "content": [{"text": f"agent.tool.{tool['name']} was called."}], + } + + # Add to message history + self.messages.append(user_msg) + self.messages.append(tool_use_msg) + self.messages.append(tool_result_msg) + self.messages.append(assistant_msg) + + logger.debug("Direct tool call recorded in message history: %s", tool["name"]) + + def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: + """Filter input parameters to only include those defined in the tool specification. + + Args: + tool_name: Name of the tool to get specification for + input_params: Original input parameters + + Returns: + Filtered parameters containing only those defined in tool spec + """ + all_tools_config = self.tool_registry.get_all_tools_config() + tool_spec = all_tools_config.get(tool_name) + + if not tool_spec or "inputSchema" not in tool_spec: + return input_params.copy() + + properties = tool_spec["inputSchema"]["json"]["properties"] + return {k: v for k, v in input_params.items() if k in properties} + async def start(self) -> None: """Start a persistent bidirectional conversation session. 
diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 16be08aaf..69f5d759d 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -12,12 +12,13 @@ import asyncio -import json import logging import traceback import uuid +from ....telemetry.metrics import Trace from ....tools._validator import validate_and_prepare_tools +from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession @@ -59,6 +60,9 @@ def __init__(self, model_session: BidirectionalModelSession, agent: "Bidirection # Interruption handling (model-agnostic) self.interrupted = False self.interruption_lock = asyncio.Lock() + + # Tool execution tracking + self.tool_count = 0 async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: @@ -195,11 +199,11 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: # Cancel all pending tool execution tasks cancelled_tools = 0 - for task_id, task in list(session.pending_tool_tasks.items()): + for _task_id, task in list(session.pending_tool_tasks.items()): if not task.done(): task.cancel() cancelled_tools += 1 - logger.debug("Tool task cancelled: %s", task_id) + logger.debug("Tool task cancelled: %s", _task_id) if cancelled_tools > 0: logger.debug("Tool tasks cancelled: %d", cancelled_tools) @@ -274,7 +278,8 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Queue tool requests for concurrent execution if strands_event.get("toolUse"): - logger.debug("Tool queued: %s", strands_event["toolUse"].get("name")) + tool_name = strands_event["toolUse"].get("name") + logger.debug("Tool usage detected: %s", tool_name) await session.tool_queue.put(strands_event["toolUse"]) continue @@ -316,7 +321,13 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: while session.active: try: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) - logger.debug("Tool execution started: %s (id: %s)", tool_use.get("name"), tool_use.get("toolUseId")) + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + + session.tool_count += 1 + logger.info("Tool #%d: %s", session.tool_count, tool_name) + + logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id) task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) @@ -330,11 +341,11 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: # Log completion status if completed_task.cancelled(): - logger.debug("Tool task cleanup cancelled: %s", task_id) + logger.debug("Tool task cancelled: %s", task_id) elif completed_task.exception(): - logger.error("Tool task cleanup error: %s - %s", task_id, str(completed_task.exception())) + logger.error("Tool task error: %s - %s", task_id, str(completed_task.exception())) else: - logger.debug("Tool task cleanup success: %s", task_id) + logger.debug("Tool task completed: %s", task_id) except Exception as e: logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) @@ -365,94 +376,106 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> 
None: async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: - """Execute tool using Strands infrastructure with interruption support. - - Executes tools using the existing Strands tool system with proper asyncio - cancellation handling. Tool execution is stopped via task cancellation, - not manual state checks. - + """Execute tool using the complete Strands tool execution system. + + Uses proper Strands ToolExecutor system with validation, error handling, + and event streaming. + Args: session: BidirectionalConnection for context. tool_use: Tool use event to execute. """ tool_name = tool_use.get("name") tool_id = tool_use.get("toolUseId") - + + logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) + try: - # Create message structure for existing tool system + # Create message structure for validation tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - + + # Use Strands validation system tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] - - # Validate using existing Strands validation + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - - # Filter valid tool uses + + # Filter valid tools valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - + if not valid_tool_uses: - logger.warning("Tool validation failed: %s (id: %s)", tool_name, tool_id) + logger.warning("No valid tools after validation: %s", tool_name) return - - # Execute tools directly (simpler approach for bidirectional) - for tool_use in valid_tool_uses: - tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) - - if tool_func: - try: - actual_func = _extract_callable_function(tool_func) - - # Execute tool function with provided input - result = actual_func(**tool_use.get("input", {})) - - tool_result = _create_success_result(tool_use["toolUseId"], result) - tool_results.append(tool_result) - - except Exception as e: - logger.error("Tool execution failed: %s - %s", tool_name, str(e)) - tool_result = _create_error_result(tool_use["toolUseId"], str(e)) - tool_results.append(tool_result) - else: - logger.warning("Tool not found: %s", tool_name) - - # Send results through provider-specific session - for result in tool_results: - await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) - - logger.debug("Tool execution completed: %s (%d results)", tool_name, len(tool_results)) - + + # Create invocation state for tool execution + invocation_state = { + "agent": session.agent, + "model": session.agent.model, + "messages": session.agent.messages, + "system_prompt": session.agent.system_prompt, + } + + # Create cycle trace and span + cycle_trace = Trace("Bidirectional Tool Execution") + cycle_span = None + + tool_events = session.agent.tool_executor._execute( + session.agent, + valid_tool_uses, + tool_results, + cycle_trace, + cycle_span, + invocation_state + ) + + # Process tool events and send results to provider + async for tool_event in tool_events: + if isinstance(tool_event, ToolResultEvent): + tool_result = tool_event.tool_result + tool_use_id = tool_result.get("toolUseId") + + # Send result through provider-specific session + await session.model_session.send_tool_result(tool_use_id, tool_result) + logger.debug("Tool result sent: %s", tool_use_id) + + # Handle streaming events if needed later + elif isinstance(tool_event, ToolStreamEvent): + logger.debug("Tool stream event: %s", 
tool_event) + pass + + # Add tool result message to conversation history + if tool_results: + from ....hooks import MessageAddedEvent + + tool_result_message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + + session.agent.messages.append(tool_result_message) + session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message)) + logger.debug("Tool result message added to history: %s", tool_name) + + logger.debug("Tool execution completed: %s", tool_name) + except asyncio.CancelledError: - # Task was cancelled due to interruption - this is expected behavior - logger.debug("Tool task cancelled gracefully: %s (id: %s)", tool_name, tool_id) - raise # Re-raise to properly handle cancellation + logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) + raise except Exception as e: - logger.error("Tool execution error: %s - %s", tool_use.get("name"), str(e)) + logger.error("Tool execution error: %s - %s", tool_name, str(e)) + # Send error result + error_result: ToolResult = { + "toolUseId": tool_id, + "status": "error", + "content": [{"text": f"Error: {str(e)}"}] + } try: - await session.model_session.send_tool_result(tool_use.get("toolUseId"), {"error": str(e)}) - except Exception as send_error: - logger.error("Tool error send failed: %s", str(send_error)) - - -def _extract_callable_function(tool_func: any) -> any: - """Extract the callable function from different tool object types.""" - if hasattr(tool_func, "_tool_func"): - return tool_func._tool_func - elif hasattr(tool_func, "func"): - return tool_func.func - elif callable(tool_func): - return tool_func - else: - raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") - - -def _create_success_result(tool_use_id: str, result: any) -> dict[str, any]: - """Create a successful tool result.""" - return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} + await session.model_session.send_tool_result(tool_id, error_result) + logger.debug("Error result sent: %s", tool_id) + except Exception: + logger.error("Failed to send error result: %s", tool_id) + pass # Session might be closed -def _create_error_result(tool_use_id: str, error: str) -> dict[str, any]: - """Create an error tool result.""" - return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 6cba974e0..882f89eef 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -3,4 +3,11 @@ from .bidirectional_model import BidirectionalModel, BidirectionalModelSession from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession -__all__ = ["BidirectionalModel", "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession"] +__all__ = [ + "BidirectionalModel", + "BidirectionalModelSession", + "NovaSonicBidirectionalModel", + "NovaSonicSession", + "OpenAIRealtimeBidirectionalModel", + "OpenAIRealtimeSession", +] diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 7f7937ef1..a1d61e11a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ 
b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -121,7 +121,7 @@ async def initialize( init_events = self._build_initialization_events(system_prompt, tools or [], messages) - logger.debug(f"Nova Sonic initialization - sending {len(init_events)} events") + logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) await self._send_initialization_events(init_events) logger.info("Nova Sonic connection initialized successfully") @@ -146,7 +146,7 @@ def _build_initialization_events( async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events with required delays.""" - for i, event in enumerate(events): + for event in events: await self._send_nova_event(event) await asyncio.sleep(EVENT_DELAY) @@ -167,12 +167,12 @@ async def _process_responses(self) -> None: await asyncio.sleep(0.1) continue except Exception as e: - logger.warning(f"Nova Sonic response error: {e}") + logger.warning("Nova Sonic response error: %s", e) await asyncio.sleep(0.1) continue except Exception as e: - logger.error(f"Nova Sonic fatal error: {e}") + logger.error("Nova Sonic fatal error: %s", e) finally: logger.debug("Nova Sonic response processor stopped") @@ -190,7 +190,7 @@ async def _handle_response_data(self, response_data: str) -> None: await self._event_queue.put(nova_event) except json.JSONDecodeError as e: - logger.warning(f"Nova Sonic JSON decode error: {e}") + logger.warning("Nova Sonic JSON decode error: %s", e) def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" @@ -383,11 +383,9 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No self._get_content_end_event(content_name), ] - for i, event in enumerate(events): + for event in events: await self._send_nova_event(event) - - async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: @@ -490,7 +488,14 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No - # Handle usage events (ignore) + # Handle usage events elif "usageEvent" in nova_event: - return None + usage_data = nova_event["usageEvent"] + usage_metrics: UsageMetricsEvent = { + "totalTokens": usage_data.get("totalTokens"), + "inputTokens": usage_data.get("totalInputTokens"), + "outputTokens": usage_data.get("totalOutputTokens"), + "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens"), + } + return {"usageMetrics": usage_metrics} # Handle content start events (track role) elif "contentStart" in nova_event: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py new file mode 100644 index 000000000..7c79e3e6c --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -0,0 +1,520 @@ +"""OpenAI Realtime API provider for Strands bidirectional streaming. + +Provides real-time audio and text communication through OpenAI's Realtime API +with WebSocket connections, voice activity detection, and function calling. 
+""" + +import asyncio +import base64 +import json +import logging +import uuid +from typing import AsyncIterable + +import websockets +from websockets.client import WebSocketClientProtocol +from websockets.exceptions import ConnectionClosed + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + BidirectionalStreamEvent, + TextOutputEvent, + VoiceActivityEvent, +) +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# OpenAI Realtime API configuration +OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" +DEFAULT_MODEL = "gpt-realtime" + +AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} + +DEFAULT_SESSION_CONFIG = { + "type": "realtime", + "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", + "output_modalities": ["audio"], + "audio": { + "input": { + "format": AUDIO_FORMAT, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + }, + }, + "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, + }, +} + + +class OpenAIRealtimeSession(BidirectionalModelSession): + """OpenAI Realtime API session for real-time audio/text streaming. + + Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, + function calling, and event conversion to Strands format. + """ + + def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: + """Initialize OpenAI Realtime session.""" + self.websocket = websocket + self.config = config + self.session_id = str(uuid.uuid4()) + self._active = True + + self._event_queue = asyncio.Queue() + self._response_task = None + self._function_call_buffer = {} + + logger.debug("OpenAI Realtime session initialized: %s", self.session_id) + + def _require_active(self) -> bool: + """Check if session is active.""" + return self._active + + def _create_text_event(self, text: str, role: str) -> dict[str, any]: + """Create standardized text output event.""" + text_output: TextOutputEvent = {"text": text, "role": role} + return {"textOutput": text_output} + + def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: + """Create standardized voice activity event.""" + voice_activity: VoiceActivityEvent = {"activityType": activity_type} + return {"voiceActivity": voice_activity} + + async def _create_conversation_item(self, item_data: dict) -> None: + """Create conversation item and trigger response.""" + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def initialize( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + ) -> None: + """Initialize session with configuration.""" + try: + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) + + if messages: + await self._add_conversation_history(messages) + + self._response_task = asyncio.create_task(self._process_responses()) + logger.info("OpenAI Realtime session initialized successfully") + + except Exception as e: + logger.error("Error during OpenAI Realtime initialization: %s", e) + raise + + def _build_session_config(self, 
system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: + """Build session configuration for OpenAI Realtime API.""" + config = DEFAULT_SESSION_CONFIG.copy() + + if system_prompt: + config["instructions"] = system_prompt + + if tools: + config["tools"] = self._convert_tools_to_openai_format(tools) + + custom_config = self.config.get("session", {}) + supported_params = { + "type", + "output_modalities", + "instructions", + "voice", + "audio", + "tools", + "tool_choice", + "input_audio_format", + "output_audio_format", + "input_audio_transcription", + "turn_detection", + } + + for key, value in custom_config.items(): + if key in supported_params: + config[key] = value + else: + logger.warning("Ignoring unsupported session parameter: %s", key) + + return config + + def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: + """Convert Strands tool specifications to OpenAI Realtime API format.""" + openai_tools = [] + + for tool in tools: + input_schema = tool["inputSchema"] + if "json" in input_schema: + schema = ( + json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + ) + else: + schema = input_schema + + # OpenAI Realtime API expects flat structure, not nested under "function" + openai_tool = { + "type": "function", + "name": tool["name"], + "description": tool["description"], + "parameters": schema, + } + openai_tools.append(openai_tool) + + return openai_tools + + async def _add_conversation_history(self, messages: Messages) -> None: + """Add conversation history to the session.""" + for message in messages: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": message["role"], "content": []}, + } + + content = message.get("content", "") + if isinstance(content, str): + conversation_item["item"]["content"].append({"type": "input_text", "text": content}) + elif isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + conversation_item["item"]["content"].append( + {"type": "input_text", "text": item.get("text", "")} + ) + + await self._send_event(conversation_item) + + async def _process_responses(self) -> None: + """Process incoming WebSocket messages.""" + logger.debug("OpenAI Realtime response processor started") + + try: + async for message in self.websocket: + if not self._active: + break + + try: + event = json.loads(message) + await self._event_queue.put(event) + except json.JSONDecodeError as e: + logger.warning("Failed to parse OpenAI event: %s", e) + continue + + except ConnectionClosed: + logger.debug("OpenAI Realtime WebSocket connection closed") + except Exception as e: + logger.error("Error in OpenAI Realtime response processing: %s", e) + finally: + self._active = False + logger.debug("OpenAI Realtime response processor stopped") + + async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive OpenAI events and convert to Strands format.""" + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.session_id, + "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, + } + yield {"BidirectionalConnectionStart": connection_start} + + try: + while self._active: + try: + openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + provider_event = self._convert_openai_event(openai_event) + if provider_event: + yield provider_event + except asyncio.TimeoutError: + continue + + except Exception as 
e: + logger.error("Error receiving OpenAI Realtime event: %s", e) + finally: + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.session_id, + "reason": "connection_complete", + "metadata": {"provider": "openai_realtime"}, + } + yield {"BidirectionalConnectionEnd": connection_end} + + def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] | None: + """Convert OpenAI events to Strands format.""" + event_type = openai_event.get("type") + + # Audio output + if event_type == "response.output_audio.delta": + audio_data = base64.b64decode(openai_event["delta"]) + audio_output: AudioOutputEvent = { + "audioData": audio_data, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": None, + } + return {"audioOutput": audio_output} + + # Text output using helper method + elif event_type == "response.output_text.delta": + return self._create_text_event(openai_event["delta"], "assistant") + + elif event_type == "response.output_audio_transcript.delta": + return self._create_text_event(openai_event["delta"], "assistant") + + # User transcription + elif event_type == "conversation.item.input_audio_transcription.delta": + transcript_delta = openai_event.get("delta", "") + return self._create_text_event(transcript_delta, "user") if transcript_delta.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.completed": + transcript = openai_event.get("transcript", "") + return self._create_text_event(transcript, "user") if transcript.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.segment": + segment_data = openai_event.get("segment", {}) + text = segment_data.get("text", "") + return self._create_text_event(text, "user") if text.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.failed": + error_info = openai_event.get("error", {}) + logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) + return None + + # Function call processing + elif event_type == "response.function_call_arguments.delta": + call_id = openai_event.get("call_id") + delta = openai_event.get("delta", "") + if call_id: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} + else: + self._function_call_buffer[call_id]["arguments"] += delta + return None + + elif event_type == "response.function_call_arguments.done": + call_id = openai_event.get("call_id") + if call_id and call_id in self._function_call_buffer: + function_call = self._function_call_buffer[call_id] + try: + tool_use: ToolUse = { + "toolUseId": call_id, + "name": function_call["name"], + "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, + } + del self._function_call_buffer[call_id] + return {"toolUse": tool_use} + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Error parsing function arguments for %s: %s", call_id, e) + del self._function_call_buffer[call_id] + return None + + # Voice activity detection using helper method + elif event_type == "input_audio_buffer.speech_started": + return self._create_voice_activity_event("speech_started") + elif event_type == "input_audio_buffer.speech_stopped": + return self._create_voice_activity_event("speech_stopped") + elif event_type == "input_audio_buffer.timeout_triggered": + return self._create_voice_activity_event("timeout") + + # Lifecycle events (log only) + elif event_type == 
"conversation.item.retrieve": + item = openai_event.get("item", {}) + logger.debug("OpenAI conversation item retrieved: %s", item.get("id")) + return None + + elif event_type == "conversation.item.added": + logger.debug("OpenAI conversation item added: %s", openai_event.get("item", {}).get("id")) + return None + + elif event_type == "conversation.item.done": + logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) + + item = openai_event.get("item", {}) + if item.get("type") == "message" and item.get("role") == "assistant": + content_parts = item.get("content", []) + if content_parts: + message_content = [] + for content_part in content_parts: + if content_part.get("type") == "output_text": + message_content.append({"type": "text", "text": content_part.get("text", "")}) + elif content_part.get("type") == "output_audio": + transcript = content_part.get("transcript", "") + if transcript: + message_content.append({"type": "text", "text": transcript}) + + if message_content: + message = {"role": "assistant", "content": message_content} + return {"messageStop": {"message": message}} + return None + + elif event_type in [ + "response.output_item.added", + "response.output_item.done", + "response.content_part.added", + "response.content_part.done", + ]: + item_data = openai_event.get("item") or openai_event.get("part") + logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") + + # Track function call names from response.output_item.added + if event_type == "response.output_item.added": + item = openai_event.get("item", {}) + if item.get("type") == "function_call": + call_id = item.get("call_id") + function_name = item.get("name") + if call_id and function_name: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = { + "call_id": call_id, + "name": function_name, + "arguments": "", + } + else: + self._function_call_buffer[call_id]["name"] = function_name + return None + + elif event_type in [ + "input_audio_buffer.committed", + "input_audio_buffer.cleared", + "session.created", + "session.updated", + ]: + logger.debug("OpenAI %s event", event_type) + return None + + elif event_type == "error": + logger.error("OpenAI Realtime error: %s", openai_event.get("error", {})) + return None + + else: + logger.debug("Unhandled OpenAI event type: %s", event_type) + return None + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content to OpenAI for processing.""" + if not self._require_active(): + return + + audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content to OpenAI for processing.""" + if not self._require_active(): + return + + item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} + await self._create_conversation_item(item_data) + + async def send_interrupt(self) -> None: + """Send interruption signal to OpenAI.""" + if not self._require_active(): + return + + await self._send_event({"type": "response.cancel"}) + + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: + """Send tool result back to OpenAI.""" + if not self._require_active(): + return + + logger.debug("OpenAI tool result send: %s", tool_use_id) + result_text = json.dumps(result) if not isinstance(result, str) else result + 
+ item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_text} + await self._create_conversation_item(item_data) + + async def close(self) -> None: + """Close session and cleanup resources.""" + if not self._active: + return + + logger.debug("OpenAI Realtime cleanup - starting connection close") + self._active = False + + if self._response_task and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + await self.websocket.close() + except Exception as e: + logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) + + logger.debug("OpenAI Realtime connection closed") + + async def _send_event(self, event: dict[str, any]) -> None: + """Send event to OpenAI via WebSocket.""" + try: + message = json.dumps(event) + await self.websocket.send(message) + logger.debug("Sent OpenAI event: %s", event.get("type")) + except Exception as e: + logger.error("Error sending OpenAI event: %s", e) + raise + + +class OpenAIRealtimeBidirectionalModel(BidirectionalModel): + """OpenAI Realtime API provider for Strands bidirectional streaming. + + Provides real-time audio/text communication through OpenAI's Realtime API + with WebSocket connections, voice activity detection, and function calling. + """ + + def __init__(self, model: str = DEFAULT_MODEL, api_key: str | None = None, **config: any) -> None: + """Initialize OpenAI Realtime bidirectional model.""" + self.model = model + self.api_key = api_key + self.config = config + + import os + + if not self.api_key: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError( + "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." 
+ ) + + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) + + async def create_bidirectional_connection( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> BidirectionalModelSession: + """Create bidirectional connection to OpenAI Realtime API.""" + logger.info("Creating OpenAI Realtime connection...") + + try: + url = f"{OPENAI_REALTIME_URL}?model={self.model}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if "organization" in self.config: + headers.append(("OpenAI-Organization", self.config["organization"])) + if "project" in self.config: + headers.append(("OpenAI-Project", self.config["project"])) + + websocket = await websockets.connect(url, additional_headers=headers) + logger.info("WebSocket connected successfully") + + session = OpenAIRealtimeSession(websocket, self.config) + await session.initialize(system_prompt, tools, messages) + + logger.info("OpenAI Realtime connection established") + return session + + except Exception as e: + logger.error("OpenAI connection error: %s", e) + raise diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py new file mode 100644 index 000000000..5ce4b8cb2 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 +"""Test OpenAI Realtime API speech-to-speech interaction.""" + +import asyncio +import os +import sys +import time +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent / "src")) + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel + + +def test_direct_tool_calling(): + """Test direct tool calling functionality.""" + print("Testing direct tool calling...") + + try: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("OPENAI_API_KEY not set - skipping test") + return + + model = OpenAIRealtimeBidirectionalModel(model="gpt-4o-realtime-preview", api_key=api_key) + agent = BidirectionalAgent(model=model, tools=[calculator]) + + # Test calculator + result = agent.tool.calculator(expression="2 * 3") + content = result.get("content", [{}])[0].get("text", "") + print(f"Result: {content}") + print("Test completed") + + except Exception as e: + print(f"Test failed: {e}") + + +async def play(context): + """Handle audio playback with interruption support.""" + audio = pyaudio.PyAudio() + + try: + speaker = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # OpenAI Realtime uses 24kHz + output=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + # Check for interruption + if context.get("interrupted", False): + # Clear audio queue on interruption + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get audio data with timeout + try: + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + # Play in chunks to allow interruption + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + if context.get("interrupted", 
False) or not context["active"]: + break + + chunk = audio_data[i:i + chunk_size] + speaker.write(chunk) + await asyncio.sleep(0.001) # Brief pause for responsiveness + + except asyncio.TimeoutError: + continue + + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Audio playback error: {e}") + finally: + try: + speaker.close() + except Exception: + pass + audio.terminate() + + +async def record(context): + """Handle microphone recording.""" + audio = pyaudio.PyAudio() + + try: + microphone = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # Match OpenAI's expected input rate + input=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + await context["audio_in"].put(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Microphone recording error: {e}") + finally: + try: + microphone.close() + except Exception: + pass + audio.terminate() + + +async def receive(agent, context): + """Handle events from the agent.""" + try: + async for event in agent.receive(): + if not context["active"]: + break + + # Handle audio output + if "audioOutput" in event: + audio_data = event["audioOutput"]["audioData"] + + if not context.get("interrupted", False): + await context["audio_out"].put(audio_data) + + # Handle text output (transcripts) + elif "textOutput" in event: + text_output = event["textOutput"] + role = text_output.get("role", "assistant") + text = text_output.get("text", "").strip() + + if text: + if role == "user": + print(f"User: {text}") + elif role == "assistant": + print(f"Assistant: {text}") + + # Handle interruption detection + elif "interruptionDetected" in event: + context["interrupted"] = True + + # Handle connection events + elif "BidirectionalConnectionStart" in event: + pass # Silent connection start + elif "BidirectionalConnectionEnd" in event: + context["active"] = False + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Receive handler error: {e}") + finally: + pass + + +async def send(agent, context): + """Send audio from microphone to agent.""" + try: + while context["active"]: + try: + audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) + + # Create audio event in expected format + audio_event = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1 + } + + await agent.send(audio_event) + + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Send handler error: {e}") + finally: + pass + + +async def main(): + """Main test function for OpenAI voice chat.""" + print("Starting OpenAI Realtime API test...") + + # Check API key + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("OPENAI_API_KEY environment variable not set") + return False + + # Check audio system + try: + audio = pyaudio.PyAudio() + audio.terminate() + except Exception as e: + print(f"Audio system error: {e}") + return False + + # Create OpenAI model + model = OpenAIRealtimeBidirectionalModel( + model="gpt-4o-realtime-preview", + api_key=api_key, + session={ + "output_modalities": ["audio"], + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + 
"silence_duration_ms": 700 + } + }, + "output": { + "format": {"type": "audio/pcm", "rate": 24000}, + "voice": "alloy" + } + } + } + ) + + # Create agent + agent = BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt=( + "You are a helpful voice assistant. Keep your responses brief and natural. " + "Say hello when you first connect." + ) + ) + + # Start the session + await agent.start() + + # Create shared context + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "interrupted": False, + "start_time": time.time() + } + + print("Speak into your microphone. Press Ctrl+C to stop.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + return_exceptions=True + ) + + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + except Exception as e: + print(f"\nError during voice chat: {e}") + finally: + print("Cleaning up...") + context["active"] = False + + try: + await agent.end() + except Exception as e: + print(f"Cleanup error: {e}") + + return True + + +if __name__ == "__main__": + # Test direct tool calling first + print("OpenAI Realtime API Test Suite") + test_direct_tool_calling() + + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Test error: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index b31607966..8c3ae3b4c 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -10,6 +10,7 @@ # Add the src directory to Python path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import os import time import pyaudio @@ -19,6 +20,29 @@ from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +def test_direct_tools(): + """Test direct tool calling.""" + print("Testing direct tool calling...") + + # Check AWS credentials + if not all([os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY")]): + print("AWS credentials not set - skipping test") + return + + try: + model = NovaSonicBidirectionalModel() + agent = BidirectionalAgent(model=model, tools=[calculator]) + + # Test calculator + result = agent.tool.calculator(expression="2 * 3") + content = result.get("content", [{}])[0].get("text", "") + print(f"Result: {content}") + print("Test completed") + + except Exception as e: + print(f"Test failed: {e}") + + async def play(context): """Play audio output with responsive interruption support.""" audio = pyaudio.PyAudio() @@ -195,4 +219,7 @@ async def main(duration=180): if __name__ == "__main__": + # Test direct tool calling first + test_direct_tools() + asyncio.run(main()) From ee12db36c34e786fef880d9699d6696d41ffa14c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 14 Oct 2025 08:41:45 -0400 Subject: [PATCH 2/2] feat(tool_executor): Plug tool executor into bidirectional streaming implementation --- .../bidirectional_streaming/__init__.py | 4 - .../models/__init__.py | 2 - .../models/novasonic.py | 1 + .../bidirectional_streaming/models/openai.py | 522 
------ ...al_streaming.py => test_bidi_novasonic.py} | 0 .../tests/test_bidi_openai.py | 317 ----------- .../types/bidirectional_streaming.py | 35 +- 7 files changed, 29 insertions(+), 852 deletions(-) delete mode 100644 src/strands/experimental/bidirectional_streaming/models/openai.py rename src/strands/experimental/bidirectional_streaming/tests/{test_bidirectional_streaming.py => test_bidi_novasonic.py} (100%) delete mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 844a8a1f8..0f842ee9f 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -8,7 +8,6 @@ # Model providers - What users need to create models from .models.novasonic import NovaSonicBidirectionalModel -from .models.openai import OpenAIRealtimeBidirectionalModel # Event types - For type hints and event handling from .types.bidirectional_streaming import ( @@ -18,7 +17,6 @@ InterruptionDetectedEvent, TextOutputEvent, UsageMetricsEvent, - VoiceActivityEvent, ) __all__ = [ @@ -26,14 +24,12 @@ "BidirectionalAgent", # Model providers "NovaSonicBidirectionalModel", - "OpenAIRealtimeBidirectionalModel", # Event types "AudioInputEvent", "AudioOutputEvent", "TextOutputEvent", "InterruptionDetectedEvent", "BidirectionalStreamEvent", - "VoiceActivityEvent", "UsageMetricsEvent", # Model interface "BidirectionalModel", diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 882f89eef..3a785e98a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -8,6 +8,4 @@ "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession", - "OpenAIRealtimeBidirectionalModel", - "OpenAIRealtimeSession", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index a1d61e11a..7f35a3c1c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -35,6 +35,7 @@ BidirectionalConnectionStartEvent, InterruptionDetectedEvent, TextOutputEvent, + UsageMetricsEvent, ) from .bidirectional_model import BidirectionalModel, BidirectionalModelSession diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py deleted file mode 100644 index 7c79e3e6c..000000000 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ /dev/null @@ -1,520 +0,0 @@ -"""OpenAI Realtime API provider for Strands bidirectional streaming. - -Provides real-time audio and text communication through OpenAI's Realtime API -with WebSocket connections, voice activity detection, and function calling. 
-""" - -import asyncio -import base64 -import json -import logging -import uuid -from typing import AsyncIterable - -import websockets -from websockets.client import WebSocketClientProtocol -from websockets.exceptions import ConnectionClosed - -from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse -from ..types.bidirectional_streaming import ( - AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - BidirectionalConnectionStartEvent, - BidirectionalStreamEvent, - TextOutputEvent, - VoiceActivityEvent, -) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession - -logger = logging.getLogger(__name__) - -# OpenAI Realtime API configuration -OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" -DEFAULT_MODEL = "gpt-realtime" - -AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} - -DEFAULT_SESSION_CONFIG = { - "type": "realtime", - "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", - "output_modalities": ["audio"], - "audio": { - "input": { - "format": AUDIO_FORMAT, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "prefix_padding_ms": 300, - "silence_duration_ms": 500, - }, - }, - "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, - }, -} - - -class OpenAIRealtimeSession(BidirectionalModelSession): - """OpenAI Realtime API session for real-time audio/text streaming. - - Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, - function calling, and event conversion to Strands format. - """ - - def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: - """Initialize OpenAI Realtime session.""" - self.websocket = websocket - self.config = config - self.session_id = str(uuid.uuid4()) - self._active = True - - self._event_queue = asyncio.Queue() - self._response_task = None - self._function_call_buffer = {} - - logger.debug("OpenAI Realtime session initialized: %s", self.session_id) - - def _require_active(self) -> bool: - """Check if session is active.""" - return self._active - - def _create_text_event(self, text: str, role: str) -> dict[str, any]: - """Create standardized text output event.""" - text_output: TextOutputEvent = {"text": text, "role": role} - return {"textOutput": text_output} - - def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: - """Create standardized voice activity event.""" - voice_activity: VoiceActivityEvent = {"activityType": activity_type} - return {"voiceActivity": voice_activity} - - async def _create_conversation_item(self, item_data: dict) -> None: - """Create conversation item and trigger response.""" - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) - - async def initialize( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - ) -> None: - """Initialize session with configuration.""" - try: - session_config = self._build_session_config(system_prompt, tools) - await self._send_event({"type": "session.update", "session": session_config}) - - if messages: - await self._add_conversation_history(messages) - - self._response_task = asyncio.create_task(self._process_responses()) - logger.info("OpenAI Realtime session initialized successfully") - - except Exception as e: - logger.error("Error during OpenAI Realtime initialization: %s", e) - raise - - def _build_session_config(self, 
system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: - """Build session configuration for OpenAI Realtime API.""" - config = DEFAULT_SESSION_CONFIG.copy() - - if system_prompt: - config["instructions"] = system_prompt - - if tools: - config["tools"] = self._convert_tools_to_openai_format(tools) - - custom_config = self.config.get("session", {}) - supported_params = { - "type", - "output_modalities", - "instructions", - "voice", - "audio", - "tools", - "tool_choice", - "input_audio_format", - "output_audio_format", - "input_audio_transcription", - "turn_detection", - } - - for key, value in custom_config.items(): - if key in supported_params: - config[key] = value - else: - logger.warning("Ignoring unsupported session parameter: %s", key) - - return config - - def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: - """Convert Strands tool specifications to OpenAI Realtime API format.""" - openai_tools = [] - - for tool in tools: - input_schema = tool["inputSchema"] - if "json" in input_schema: - schema = ( - json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] - ) - else: - schema = input_schema - - # OpenAI Realtime API expects flat structure, not nested under "function" - openai_tool = { - "type": "function", - "name": tool["name"], - "description": tool["description"], - "parameters": schema, - } - openai_tools.append(openai_tool) - - return openai_tools - - async def _add_conversation_history(self, messages: Messages) -> None: - """Add conversation history to the session.""" - for message in messages: - conversation_item = { - "type": "conversation.item.create", - "item": {"type": "message", "role": message["role"], "content": []}, - } - - content = message.get("content", "") - if isinstance(content, str): - conversation_item["item"]["content"].append({"type": "input_text", "text": content}) - elif isinstance(content, list): - for item in content: - if isinstance(item, dict) and item.get("type") == "text": - conversation_item["item"]["content"].append( - {"type": "input_text", "text": item.get("text", "")} - ) - - await self._send_event(conversation_item) - - async def _process_responses(self) -> None: - """Process incoming WebSocket messages.""" - logger.debug("OpenAI Realtime response processor started") - - try: - async for message in self.websocket: - if not self._active: - break - - try: - event = json.loads(message) - await self._event_queue.put(event) - except json.JSONDecodeError as e: - logger.warning("Failed to parse OpenAI event: %s", e) - continue - - except ConnectionClosed: - logger.debug("OpenAI Realtime WebSocket connection closed") - except Exception as e: - logger.error("Error in OpenAI Realtime response processing: %s", e) - finally: - self._active = False - logger.debug("OpenAI Realtime response processor stopped") - - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive OpenAI events and convert to Strands format.""" - connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.session_id, - "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, - } - yield {"BidirectionalConnectionStart": connection_start} - - try: - while self._active: - try: - openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) - provider_event = self._convert_openai_event(openai_event) - if provider_event: - yield provider_event - except asyncio.TimeoutError: - continue - - except Exception as 
e: - logger.error("Error receiving OpenAI Realtime event: %s", e) - finally: - connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.session_id, - "reason": "connection_complete", - "metadata": {"provider": "openai_realtime"}, - } - yield {"BidirectionalConnectionEnd": connection_end} - - def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] | None: - """Convert OpenAI events to Strands format.""" - event_type = openai_event.get("type") - - # Audio output - if event_type == "response.output_audio.delta": - audio_data = base64.b64decode(openai_event["delta"]) - audio_output: AudioOutputEvent = { - "audioData": audio_data, - "format": "pcm", - "sampleRate": 24000, - "channels": 1, - "encoding": None, - } - return {"audioOutput": audio_output} - - # Text output using helper method - elif event_type == "response.output_text.delta": - return self._create_text_event(openai_event["delta"], "assistant") - - elif event_type == "response.output_audio_transcript.delta": - return self._create_text_event(openai_event["delta"], "assistant") - - # User transcription - elif event_type == "conversation.item.input_audio_transcription.delta": - transcript_delta = openai_event.get("delta", "") - return self._create_text_event(transcript_delta, "user") if transcript_delta.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.completed": - transcript = openai_event.get("transcript", "") - return self._create_text_event(transcript, "user") if transcript.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.segment": - segment_data = openai_event.get("segment", {}) - text = segment_data.get("text", "") - return self._create_text_event(text, "user") if text.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.failed": - error_info = openai_event.get("error", {}) - logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) - return None - - # Function call processing - elif event_type == "response.function_call_arguments.delta": - call_id = openai_event.get("call_id") - delta = openai_event.get("delta", "") - if call_id: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} - else: - self._function_call_buffer[call_id]["arguments"] += delta - return None - - elif event_type == "response.function_call_arguments.done": - call_id = openai_event.get("call_id") - if call_id and call_id in self._function_call_buffer: - function_call = self._function_call_buffer[call_id] - try: - tool_use: ToolUse = { - "toolUseId": call_id, - "name": function_call["name"], - "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, - } - del self._function_call_buffer[call_id] - return {"toolUse": tool_use} - except (json.JSONDecodeError, KeyError) as e: - logger.warning("Error parsing function arguments for %s: %s", call_id, e) - del self._function_call_buffer[call_id] - return None - - # Voice activity detection using helper method - elif event_type == "input_audio_buffer.speech_started": - return self._create_voice_activity_event("speech_started") - elif event_type == "input_audio_buffer.speech_stopped": - return self._create_voice_activity_event("speech_stopped") - elif event_type == "input_audio_buffer.timeout_triggered": - return self._create_voice_activity_event("timeout") - - # Lifecycle events (log only) - elif event_type == 
"conversation.item.retrieve": - item = openai_event.get("item", {}) - logger.debug("OpenAI conversation item retrieved: %s", item.get("id")) - return None - - elif event_type == "conversation.item.added": - logger.debug("OpenAI conversation item added: %s", openai_event.get("item", {}).get("id")) - return None - - elif event_type == "conversation.item.done": - logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) - - item = openai_event.get("item", {}) - if item.get("type") == "message" and item.get("role") == "assistant": - content_parts = item.get("content", []) - if content_parts: - message_content = [] - for content_part in content_parts: - if content_part.get("type") == "output_text": - message_content.append({"type": "text", "text": content_part.get("text", "")}) - elif content_part.get("type") == "output_audio": - transcript = content_part.get("transcript", "") - if transcript: - message_content.append({"type": "text", "text": transcript}) - - if message_content: - message = {"role": "assistant", "content": message_content} - return {"messageStop": {"message": message}} - return None - - elif event_type in [ - "response.output_item.added", - "response.output_item.done", - "response.content_part.added", - "response.content_part.done", - ]: - item_data = openai_event.get("item") or openai_event.get("part") - logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") - - # Track function call names from response.output_item.added - if event_type == "response.output_item.added": - item = openai_event.get("item", {}) - if item.get("type") == "function_call": - call_id = item.get("call_id") - function_name = item.get("name") - if call_id and function_name: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = { - "call_id": call_id, - "name": function_name, - "arguments": "", - } - else: - self._function_call_buffer[call_id]["name"] = function_name - return None - - elif event_type in [ - "input_audio_buffer.committed", - "input_audio_buffer.cleared", - "session.created", - "session.updated", - ]: - logger.debug("OpenAI %s event", event_type) - return None - - elif event_type == "error": - logger.error("OpenAI Realtime error: %s", openai_event.get("error", {})) - return None - - else: - logger.debug("Unhandled OpenAI event type: %s", event_type) - return None - - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to OpenAI for processing.""" - if not self._require_active(): - return - - audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") - await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) - - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content to OpenAI for processing.""" - if not self._require_active(): - return - - item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} - await self._create_conversation_item(item_data) - - async def send_interrupt(self) -> None: - """Send interruption signal to OpenAI.""" - if not self._require_active(): - return - - await self._send_event({"type": "response.cancel"}) - - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool result back to OpenAI.""" - if not self._require_active(): - return - - logger.debug("OpenAI tool result send: %s", tool_use_id) - result_text = json.dumps(result) if not isinstance(result, str) else result - 
- item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_text} - await self._create_conversation_item(item_data) - - async def close(self) -> None: - """Close session and cleanup resources.""" - if not self._active: - return - - logger.debug("OpenAI Realtime cleanup - starting connection close") - self._active = False - - if self._response_task and not self._response_task.done(): - self._response_task.cancel() - try: - await self._response_task - except asyncio.CancelledError: - pass - - try: - await self.websocket.close() - except Exception as e: - logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) - - logger.debug("OpenAI Realtime connection closed") - - async def _send_event(self, event: dict[str, any]) -> None: - """Send event to OpenAI via WebSocket.""" - try: - message = json.dumps(event) - await self.websocket.send(message) - logger.debug("Sent OpenAI event: %s", event.get("type")) - except Exception as e: - logger.error("Error sending OpenAI event: %s", e) - raise - - -class OpenAIRealtimeBidirectionalModel(BidirectionalModel): - """OpenAI Realtime API provider for Strands bidirectional streaming. - - Provides real-time audio/text communication through OpenAI's Realtime API - with WebSocket connections, voice activity detection, and function calling. - """ - - def __init__(self, model: str = DEFAULT_MODEL, api_key: str | None = None, **config: any) -> None: - """Initialize OpenAI Realtime bidirectional model.""" - self.model = model - self.api_key = api_key - self.config = config - - import os - - if not self.api_key: - self.api_key = os.getenv("OPENAI_API_KEY") - if not self.api_key: - raise ValueError( - "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." 
- ) - - logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create bidirectional connection to OpenAI Realtime API.""" - logger.info("Creating OpenAI Realtime connection...") - - try: - url = f"{OPENAI_REALTIME_URL}?model={self.model}" - - headers = [("Authorization", f"Bearer {self.api_key}")] - if "organization" in self.config: - headers.append(("OpenAI-Organization", self.config["organization"])) - if "project" in self.config: - headers.append(("OpenAI-Project", self.config["project"])) - - websocket = await websockets.connect(url, additional_headers=headers) - logger.info("WebSocket connected successfully") - - session = OpenAIRealtimeSession(websocket, self.config) - await session.initialize(system_prompt, tools, messages) - - logger.info("OpenAI Realtime connection established") - return session - - except Exception as e: - logger.error("OpenAI connection error: %s", e) - raise diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py similarity index 100% rename from src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py rename to src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py deleted file mode 100644 index 5ce4b8cb2..000000000 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ /dev/null @@ -1,317 +0,0 @@ -#!/usr/bin/env python3 -"""Test OpenAI Realtime API speech-to-speech interaction.""" - -import asyncio -import os -import sys -import time -from pathlib import Path - -# Add the src directory to Python path -sys.path.insert(0, str(Path(__file__).parent / "src")) - -import pyaudio -from strands_tools import calculator - -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel - - -def test_direct_tool_calling(): - """Test direct tool calling functionality.""" - print("Testing direct tool calling...") - - try: - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - print("OPENAI_API_KEY not set - skipping test") - return - - model = OpenAIRealtimeBidirectionalModel(model="gpt-4o-realtime-preview", api_key=api_key) - agent = BidirectionalAgent(model=model, tools=[calculator]) - - # Test calculator - result = agent.tool.calculator(expression="2 * 3") - content = result.get("content", [{}])[0].get("text", "") - print(f"Result: {content}") - print("Test completed") - - except Exception as e: - print(f"Test failed: {e}") - - -async def play(context): - """Handle audio playback with interruption support.""" - audio = pyaudio.PyAudio() - - try: - speaker = audio.open( - format=pyaudio.paInt16, - channels=1, - rate=24000, # OpenAI Realtime uses 24kHz - output=True, - frames_per_buffer=1024, - ) - - while context["active"]: - try: - # Check for interruption - if context.get("interrupted", False): - # Clear audio queue on interruption - while not context["audio_out"].empty(): - try: - context["audio_out"].get_nowait() - except 
asyncio.QueueEmpty: - break - - context["interrupted"] = False - await asyncio.sleep(0.05) - continue - - # Get audio data with timeout - try: - audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) - - if audio_data and context["active"]: - # Play in chunks to allow interruption - chunk_size = 1024 - for i in range(0, len(audio_data), chunk_size): - if context.get("interrupted", False) or not context["active"]: - break - - chunk = audio_data[i:i + chunk_size] - speaker.write(chunk) - await asyncio.sleep(0.001) # Brief pause for responsiveness - - except asyncio.TimeoutError: - continue - - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Audio playback error: {e}") - finally: - try: - speaker.close() - except Exception: - pass - audio.terminate() - - -async def record(context): - """Handle microphone recording.""" - audio = pyaudio.PyAudio() - - try: - microphone = audio.open( - format=pyaudio.paInt16, - channels=1, - rate=24000, # Match OpenAI's expected input rate - input=True, - frames_per_buffer=1024, - ) - - while context["active"]: - try: - audio_bytes = microphone.read(1024, exception_on_overflow=False) - await context["audio_in"].put(audio_bytes) - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Microphone recording error: {e}") - finally: - try: - microphone.close() - except Exception: - pass - audio.terminate() - - -async def receive(agent, context): - """Handle events from the agent.""" - try: - async for event in agent.receive(): - if not context["active"]: - break - - # Handle audio output - if "audioOutput" in event: - audio_data = event["audioOutput"]["audioData"] - - if not context.get("interrupted", False): - await context["audio_out"].put(audio_data) - - # Handle text output (transcripts) - elif "textOutput" in event: - text_output = event["textOutput"] - role = text_output.get("role", "assistant") - text = text_output.get("text", "").strip() - - if text: - if role == "user": - print(f"User: {text}") - elif role == "assistant": - print(f"Assistant: {text}") - - # Handle interruption detection - elif "interruptionDetected" in event: - context["interrupted"] = True - - # Handle connection events - elif "BidirectionalConnectionStart" in event: - pass # Silent connection start - elif "BidirectionalConnectionEnd" in event: - context["active"] = False - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Receive handler error: {e}") - finally: - pass - - -async def send(agent, context): - """Send audio from microphone to agent.""" - try: - while context["active"]: - try: - audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) - - # Create audio event in expected format - audio_event = { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": 24000, - "channels": 1 - } - - await agent.send(audio_event) - - except asyncio.TimeoutError: - continue - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Send handler error: {e}") - finally: - pass - - -async def main(): - """Main test function for OpenAI voice chat.""" - print("Starting OpenAI Realtime API test...") - - # Check API key - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - print("OPENAI_API_KEY environment variable not set") - return False - - # Check audio system - try: - audio = pyaudio.PyAudio() - 
audio.terminate() - except Exception as e: - print(f"Audio system error: {e}") - return False - - # Create OpenAI model - model = OpenAIRealtimeBidirectionalModel( - model="gpt-4o-realtime-preview", - api_key=api_key, - session={ - "output_modalities": ["audio"], - "audio": { - "input": { - "format": {"type": "audio/pcm", "rate": 24000}, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "silence_duration_ms": 700 - } - }, - "output": { - "format": {"type": "audio/pcm", "rate": 24000}, - "voice": "alloy" - } - } - } - ) - - # Create agent - agent = BidirectionalAgent( - model=model, - tools=[calculator], - system_prompt=( - "You are a helpful voice assistant. Keep your responses brief and natural. " - "Say hello when you first connect." - ) - ) - - # Start the session - await agent.start() - - # Create shared context - context = { - "active": True, - "audio_in": asyncio.Queue(), - "audio_out": asyncio.Queue(), - "interrupted": False, - "start_time": time.time() - } - - print("Speak into your microphone. Press Ctrl+C to stop.") - - try: - # Run all tasks concurrently - await asyncio.gather( - play(context), - record(context), - receive(agent, context), - send(agent, context), - return_exceptions=True - ) - - except KeyboardInterrupt: - print("\nInterrupted by user") - except asyncio.CancelledError: - print("\nTest cancelled") - except Exception as e: - print(f"\nError during voice chat: {e}") - finally: - print("Cleaning up...") - context["active"] = False - - try: - await agent.end() - except Exception as e: - print(f"Cleanup error: {e}") - - return True - - -if __name__ == "__main__": - # Test direct tool calling first - print("OpenAI Realtime API Test Suite") - test_direct_tool_calling() - - try: - asyncio.run(main()) - except KeyboardInterrupt: - print("\nTest interrupted by user") - except Exception as e: - print(f"Test error: {e}") - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 01d72356a..c0f6eb209 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -116,10 +116,28 @@ class BidirectionalConnectionEndEvent(TypedDict): metadata: Provider-specific connection metadata. """ - reason: Literal["user_request", "timeout", "error"] + reason: Literal["user_request", "timeout", "error", "connection_complete"] connectionId: Optional[str] metadata: Optional[Dict[str, Any]] +class UsageMetricsEvent(TypedDict): + """Token usage and performance tracking. + + Provides standardized usage metrics across providers for cost monitoring + and performance optimization. + + Attributes: + totalTokens: Total tokens used in the interaction. + inputTokens: Tokens used for input processing. + outputTokens: Tokens used for output generation. + audioTokens: Tokens used specifically for audio processing. + """ + + totalTokens: Optional[int] + inputTokens: Optional[int] + outputTokens: Optional[int] + audioTokens: Optional[int] + class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. @@ -134,11 +152,14 @@ class BidirectionalStreamEvent(StreamEvent, total=False): interruptionDetected: User interruption detection. BidirectionalConnectionStart: connection start event. 
BidirectionalConnectionEnd: connection end event. + usageMetrics: Token usage and performance metrics. """ - audioOutput: AudioOutputEvent - audioInput: AudioInputEvent - textOutput: TextOutputEvent - interruptionDetected: InterruptionDetectedEvent - BidirectionalConnectionStart: BidirectionalConnectionStartEvent - BidirectionalConnectionEnd: BidirectionalConnectionEndEvent + audioOutput: Optional[AudioOutputEvent] + audioInput: Optional[AudioInputEvent] + textOutput: Optional[TextOutputEvent] + interruptionDetected: Optional[InterruptionDetectedEvent] + BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] + BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] + usageMetrics: Optional[UsageMetricsEvent] +
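
A note on consuming the new event types: the UsageMetricsEvent and the extended
"connection_complete" end reason both arrive through the agent's receive() stream,
alongside audio and text. The sketch below is illustrative only and is not part of
this patch; it assumes an already-started BidirectionalAgent bound to the name
`agent`. The dictionary keys come from the TypedDicts in this diff, while the
print destinations are arbitrary.

# Illustrative sketch: draining the event stream and reacting to the
# usageMetrics and connection-end events defined above.
async def monitor(agent):
    async for event in agent.receive():
        usage = event.get("usageMetrics")
        if usage:
            # Keys defined by UsageMetricsEvent; every field is optional.
            print(
                f"tokens: total={usage.get('totalTokens')} "
                f"input={usage.get('inputTokens')} "
                f"output={usage.get('outputTokens')} "
                f"audio={usage.get('audioTokens')}"
            )
        end = event.get("BidirectionalConnectionEnd")
        if end:
            # "connection_complete" is the reason value added in this diff.
            print(f"connection ended: {end.get('reason')}")
            break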
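On the send side, the removed test's send() loop wraps raw PCM bytes in an
AudioInputEvent-shaped dict before calling agent.send(). A minimal helper that
mirrors that shape, assuming the same 24 kHz mono 16-bit PCM the tests use:

# Illustrative sketch: wrap raw microphone bytes in the AudioInputEvent shape
# used by the tests and forward them to the agent.
async def forward_audio(agent, audio_bytes: bytes) -> None:
    await agent.send(
        {
            "audioData": audio_bytes,  # raw 16-bit PCM frames
            "format": "pcm",
            "sampleRate": 24000,
            "channels": 1,
        }
    )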
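Although the OpenAI provider is removed in this commit, the accumulate-then-parse
pattern its session used for streamed function-call arguments is worth keeping in
mind for future providers: argument deltas are buffered per call_id and
JSON-decoded only once the provider signals completion, so partial JSON is never
parsed. A provider-agnostic sketch; the function names and module-level buffer
here are hypothetical, not part of the codebase.

import json

# Illustrative sketch of per-call_id buffering for streamed tool arguments.
_buffer: dict[str, dict[str, str]] = {}

def on_arguments_delta(call_id: str, name: str, delta: str) -> None:
    # Accumulate partial JSON; a provider may interleave several calls.
    entry = _buffer.setdefault(call_id, {"name": name, "arguments": ""})
    entry["arguments"] += delta

def on_arguments_done(call_id: str) -> dict | None:
    # Parse the completed argument string into a Strands-style tool-use dict.
    entry = _buffer.pop(call_id, None)
    if entry is None:
        return None
    try:
        tool_input = json.loads(entry["arguments"]) if entry["arguments"] else {}
    except json.JSONDecodeError:
        return None
    return {"toolUseId": call_id, "name": entry["name"], "input": tool_input}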