Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/strands/experimental/bidirectional_streaming/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

# Main components - Primary user interface
from .agent.agent import BidirectionalAgent

# Advanced interfaces (for custom implementations)
from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession
from .models.base_model import BidirectionalModel
from .models.base_session import BidirectionalModelSession

# Model providers - What users need to create models
from .models.novasonic import NovaSonicBidirectionalModel
Expand Down
19 changes: 9 additions & 10 deletions src/strands/experimental/bidirectional_streaming/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@
from ....types.tools import ToolResult, ToolUse
from ....types.traces import AttributeValue
from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection
from ..models.bidirectional_model import BidirectionalModel
from ..models.base_model import BidirectionalModel
from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent


logger = logging.getLogger(__name__)

_DEFAULT_AGENT_NAME = "Strands Agents"
Expand Down Expand Up @@ -81,7 +80,7 @@ def caller(

Args:
user_message_override: Optional custom message to record instead of default
record_direct_tool_call: Whether to record direct tool calls in message history.
record_direct_tool_call: Whether to record direct tool calls in message history.
For bidirectional agents, this is always True to maintain conversation history.
**kwargs: Keyword arguments to pass to the tool.

Expand Down Expand Up @@ -186,12 +185,12 @@ def __init__(
self.model = model
self.system_prompt = system_prompt
self.messages = messages or []

# Agent identification
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
self.name = name or _DEFAULT_AGENT_NAME
self.description = description

# Tool execution configuration
self.record_direct_tool_call = record_direct_tool_call
self.load_tools_from_directory = load_tools_from_directory
Expand All @@ -207,25 +206,25 @@ def __init__(

# Initialize tool registry
self.tool_registry = ToolRegistry()

if tools is not None:
self.tool_registry.process_tools(tools)

self.tool_registry.initialize_tools(self.load_tools_from_directory)

# Initialize tool watcher if directory loading is enabled
if self.load_tools_from_directory:
self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry)

# Initialize tool executor
self.tool_executor = tool_executor or ConcurrentToolExecutor()

# Initialize hooks system
self.hooks = HookRegistry()
if hooks:
for hook in hooks:
self.hooks.add_hook(hook)

# Initialize other components
self.event_loop_metrics = EventLoopMetrics()
self.tool_caller = BidirectionalAgent.ToolCaller(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
import traceback
import uuid

from ....tools._validator import validate_and_prepare_tools
from ....telemetry.metrics import Trace
from ....tools._validator import validate_and_prepare_tools
from ....types._events import ToolResultEvent, ToolStreamEvent
from ....types.content import Message
from ....types.tools import ToolResult, ToolUse
from ..models.bidirectional_model import BidirectionalModelSession

from .. import BidirectionalAgent
from ..models.base_session import BidirectionalModelSession

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(self, model_session: BidirectionalModelSession, agent: "Bidirection
# Interruption handling (model-agnostic)
self.interrupted = False
self.interruption_lock = asyncio.Lock()

# Tool execution tracking
self.tool_count = 0

Expand Down Expand Up @@ -265,7 +265,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None:
# Basic validation - skip invalid events
if not isinstance(provider_event, dict):
continue

strands_event = provider_event

# Handle interruption detection (provider converts raw patterns to interruptionDetected)
Expand All @@ -291,7 +291,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None:
if strands_event.get("messageStop"):
logger.debug("Message added to history")
session.agent.messages.append(strands_event["messageStop"]["message"])

# Handle user audio transcripts - add to message history
if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user":
user_transcript = strands_event["textOutput"]["text"]
Expand All @@ -311,7 +311,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None:
"""Execute tools concurrently with interruption support.

Background task that manages tool execution without blocking model event
processing or user interaction. Uses proper asyncio cancellation for
processing or user interaction. Uses proper asyncio cancellation for
interruption handling rather than manual state checks.

Args:
Expand All @@ -323,10 +323,10 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None:
tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT)
tool_name = tool_use.get("name")
tool_id = tool_use.get("toolUseId")

session.tool_count += 1
print(f"\nTool #{session.tool_count}: {tool_name}")

logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id)

task_id = str(uuid.uuid4())
Expand Down Expand Up @@ -372,110 +372,96 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None:
logger.debug("Tool execution processor stopped")





async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None:
"""Execute tool using the complete Strands tool execution system.

Uses proper Strands ToolExecutor system with validation, error handling,
and event streaming.

Args:
session: BidirectionalConnection for context.
tool_use: Tool use event to execute.
"""
tool_name = tool_use.get("name")
tool_id = tool_use.get("toolUseId")

logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id)

try:
# Create message structure for validation
# Create message structure for validation
tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]}

# Use Strands validation system
tool_uses: list[ToolUse] = []
tool_results: list[ToolResult] = []
invalid_tool_use_ids: list[str] = []

validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids)

# Filter valid tools
valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids]

if not valid_tool_uses:
logger.warning("No valid tools after validation: %s", tool_name)
return

# Create invocation state for tool execution
invocation_state = {
"agent": session.agent,
"model": session.agent.model,
"messages": session.agent.messages,
"system_prompt": session.agent.system_prompt,
}

# Create cycle trace and span
cycle_trace = Trace("Bidirectional Tool Execution")
cycle_span = None

tool_events = session.agent.tool_executor._execute(
session.agent,
valid_tool_uses,
tool_results,
cycle_trace,
cycle_span,
invocation_state
session.agent, valid_tool_uses, tool_results, cycle_trace, cycle_span, invocation_state
)

# Process tool events and send results to provider
async for tool_event in tool_events:
if isinstance(tool_event, ToolResultEvent):
tool_result = tool_event.tool_result
tool_use_id = tool_result.get("toolUseId")

# Send result through provider-specific session
await session.model_session.send_tool_result(tool_use_id, tool_result)
logger.debug("Tool result sent: %s", tool_use_id)

# Handle streaming events if needed later
elif isinstance(tool_event, ToolStreamEvent):
logger.debug("Tool stream event: %s", tool_event)
pass

# Add tool result message to conversation history
if tool_results:
from ....hooks import MessageAddedEvent

tool_result_message: Message = {
"role": "user",
"content": [{"toolResult": result} for result in tool_results],
}

session.agent.messages.append(tool_result_message)
session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message))
logger.debug("Tool result message added to history: %s", tool_name)

logger.debug("Tool execution completed: %s", tool_name)

except asyncio.CancelledError:
logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id)
raise
except Exception as e:
logger.error("Tool execution error: %s - %s", tool_name, str(e))

# Send error result
error_result: ToolResult = {
"toolUseId": tool_id,
"status": "error",
"content": [{"text": f"Error: {str(e)}"}]
}

# Send error result
error_result: ToolResult = {"toolUseId": tool_id, "status": "error", "content": [{"text": f"Error: {str(e)}"}]}
try:
await session.model_session.send_tool_result(tool_id, error_result)
logger.debug("Error result sent: %s", tool_id)
except Exception:
logger.error("Failed to send error result: %s", tool_id)
pass # Session might be closed


Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Bidirectional model interfaces and implementations."""

from .bidirectional_model import BidirectionalModel, BidirectionalModelSession
from .base_model import BidirectionalModel
from .base_session import BidirectionalModelSession
from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession

__all__ = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Unified bidirectional streaming interface.

Single layer combining model and session abstractions for simpler implementation.
"""

from typing import AsyncIterable, Protocol, Union

from ....types.content import Messages
from ....types.tools import ToolResult, ToolSpec
from ..types.bidirectional_streaming import (
AudioInputEvent,
BidirectionalStreamEvent,
ImageInputEvent,
TextInputEvent,
)


class BaseModel(Protocol):
"""Unified interface for bidirectional streaming models.

Combines model configuration and session communication in a single abstraction.
Providers implement this directly without separate model/session classes.
"""

async def connect(
self,
system_prompt: str | None = None,
tools: list[ToolSpec] | None = None,
messages: Messages | None = None,
**kwargs,
) -> None:
"""Establish bidirectional connection with the model."""
...

async def close(self) -> None:
"""Close connection and cleanup resources."""
...

async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]:
"""Receive events from the model in standardized format."""
...

async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None:
"""Send structured content to the model."""
...
Loading
Loading