Skip to content

Commit 9b20e12

Browse files
authored
Merge pull request #5 from mehtarac/tool_executor
feat(tool_executor): Plug tool executor into bidirectional streaming implementation
2 parents 909fc64 + ee12db3 commit 9b20e12

File tree

7 files changed

+501
-105
lines changed

7 files changed

+501
-105
lines changed
Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,37 @@
1-
"""Bidirectional streaming package for real-time audio/text conversations."""
1+
"""Bidirectional streaming package."""
22

3+
# Main components - Primary user interface
4+
from .agent.agent import BidirectionalAgent
5+
6+
# Advanced interfaces (for custom implementations)
7+
from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession
8+
9+
# Model providers - What users need to create models
10+
from .models.novasonic import NovaSonicBidirectionalModel
11+
12+
# Event types - For type hints and event handling
13+
from .types.bidirectional_streaming import (
14+
AudioInputEvent,
15+
AudioOutputEvent,
16+
BidirectionalStreamEvent,
17+
InterruptionDetectedEvent,
18+
TextOutputEvent,
19+
UsageMetricsEvent,
20+
)
21+
22+
__all__ = [
23+
# Main interface
24+
"BidirectionalAgent",
25+
# Model providers
26+
"NovaSonicBidirectionalModel",
27+
# Event types
28+
"AudioInputEvent",
29+
"AudioOutputEvent",
30+
"TextOutputEvent",
31+
"InterruptionDetectedEvent",
32+
"BidirectionalStreamEvent",
33+
"UsageMetricsEvent",
34+
# Model interface
35+
"BidirectionalModel",
36+
"BidirectionalModelSession",
37+
]

src/strands/experimental/bidirectional_streaming/agent/agent.py

Lines changed: 288 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,32 @@
1313
"""
1414

1515
import asyncio
16+
import json
1617
import logging
17-
from typing import AsyncIterable
18+
import random
19+
from concurrent.futures import ThreadPoolExecutor
20+
from typing import Any, AsyncIterable, Callable, Mapping, Optional
1821

22+
from .... import _identifier
23+
from ....hooks import HookProvider, HookRegistry
24+
from ....telemetry.metrics import EventLoopMetrics
1925
from ....tools.executors import ConcurrentToolExecutor
26+
from ....tools.executors._executor import ToolExecutor
2027
from ....tools.registry import ToolRegistry
21-
from ....types.content import Messages
28+
from ....tools.watcher import ToolWatcher
29+
from ....types.content import Message, Messages
30+
from ....types.tools import ToolResult, ToolUse
31+
from ....types.traces import AttributeValue
2232
from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection
2333
from ..models.bidirectional_model import BidirectionalModel
2434
from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent
2535

2636

2737
logger = logging.getLogger(__name__)
2838

39+
_DEFAULT_AGENT_NAME = "Strands Agents"
40+
_DEFAULT_AGENT_ID = "default"
41+
2942

3043
class BidirectionalAgent:
3144
"""Agent for bidirectional streaming conversations.
@@ -34,12 +47,125 @@ class BidirectionalAgent:
3447
sessions. Supports concurrent tool execution and interruption handling.
3548
"""
3649

50+
class ToolCaller:
51+
"""Call tool as a function for bidirectional agent."""
52+
53+
def __init__(self, agent: "BidirectionalAgent") -> None:
54+
"""Initialize tool caller with agent reference."""
55+
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
56+
# agent tools and thus break their execution.
57+
self._agent = agent
58+
59+
def __getattr__(self, name: str) -> Callable[..., Any]:
60+
"""Call tool as a function.
61+
62+
This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
63+
It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').
64+
65+
Args:
66+
name: The name of the attribute (tool) being accessed.
67+
68+
Returns:
69+
A function that when called will execute the named tool.
70+
71+
Raises:
72+
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
73+
"""
74+
75+
def caller(
76+
user_message_override: Optional[str] = None,
77+
record_direct_tool_call: Optional[bool] = None,
78+
**kwargs: Any,
79+
) -> Any:
80+
"""Call a tool directly by name.
81+
82+
Args:
83+
user_message_override: Optional custom message to record instead of default
84+
record_direct_tool_call: Whether to record direct tool calls in message history.
85+
For bidirectional agents, this is always True to maintain conversation history.
86+
**kwargs: Keyword arguments to pass to the tool.
87+
88+
Returns:
89+
The result returned by the tool.
90+
91+
Raises:
92+
AttributeError: If the tool doesn't exist.
93+
"""
94+
normalized_name = self._find_normalized_tool_name(name)
95+
96+
# Create unique tool ID and set up the tool request
97+
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
98+
tool_use: ToolUse = {
99+
"toolUseId": tool_id,
100+
"name": normalized_name,
101+
"input": kwargs.copy(),
102+
}
103+
tool_results: list[ToolResult] = []
104+
invocation_state = kwargs
105+
106+
async def acall() -> ToolResult:
107+
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
108+
_ = event
109+
110+
return tool_results[0]
111+
112+
def tcall() -> ToolResult:
113+
return asyncio.run(acall())
114+
115+
with ThreadPoolExecutor() as executor:
116+
future = executor.submit(tcall)
117+
tool_result = future.result()
118+
119+
# Always record direct tool calls for bidirectional agents to maintain conversation history
120+
# Use agent's record_direct_tool_call setting if not overridden
121+
if record_direct_tool_call is not None:
122+
should_record_direct_tool_call = record_direct_tool_call
123+
else:
124+
should_record_direct_tool_call = self._agent.record_direct_tool_call
125+
126+
if should_record_direct_tool_call:
127+
# Create a record of this tool execution in the message history
128+
self._agent._record_tool_execution(tool_use, tool_result, user_message_override)
129+
130+
return tool_result
131+
132+
return caller
133+
134+
def _find_normalized_tool_name(self, name: str) -> str:
135+
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
136+
tool_registry = self._agent.tool_registry.registry
137+
138+
if tool_registry.get(name, None):
139+
return name
140+
141+
# If the desired name contains underscores, it might be a placeholder for characters that can't be
142+
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
143+
# all tools that can be represented with the normalized name
144+
if "_" in name:
145+
filtered_tools = [
146+
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
147+
]
148+
149+
# The registry itself defends against similar names, so we can just take the first match
150+
if filtered_tools:
151+
return filtered_tools[0]
152+
153+
raise AttributeError(f"Tool '{name}' not found")
154+
37155
def __init__(
38156
self,
39157
model: BidirectionalModel,
40158
tools: list | None = None,
41159
system_prompt: str | None = None,
42160
messages: Messages | None = None,
161+
record_direct_tool_call: bool = True,
162+
load_tools_from_directory: bool = False,
163+
agent_id: Optional[str] = None,
164+
name: Optional[str] = None,
165+
tool_executor: Optional[ToolExecutor] = None,
166+
hooks: Optional[list[HookProvider]] = None,
167+
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
168+
description: Optional[str] = None,
43169
):
44170
"""Initialize bidirectional agent with required model and optional configuration.
45171
@@ -48,24 +174,177 @@ def __init__(
48174
tools: Optional list of tools available to the model.
49175
system_prompt: Optional system prompt for conversations.
50176
messages: Optional conversation history to initialize with.
177+
record_direct_tool_call: Whether to record direct tool calls in message history.
178+
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
179+
agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios.
180+
name: Name of the Agent.
181+
tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.).
182+
hooks: Hooks to be added to the agent hook registry.
183+
trace_attributes: Custom trace attributes to apply to the agent's trace span.
184+
description: Description of what the Agent does.
51185
"""
52186
self.model = model
53187
self.system_prompt = system_prompt
54188
self.messages = messages or []
55-
56-
# Initialize tool registry using existing Strands infrastructure
189+
190+
# Agent identification
191+
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
192+
self.name = name or _DEFAULT_AGENT_NAME
193+
self.description = description
194+
195+
# Tool execution configuration
196+
self.record_direct_tool_call = record_direct_tool_call
197+
self.load_tools_from_directory = load_tools_from_directory
198+
199+
# Process trace attributes to ensure they're of compatible types
200+
self.trace_attributes: dict[str, AttributeValue] = {}
201+
if trace_attributes:
202+
for k, v in trace_attributes.items():
203+
if isinstance(v, (str, int, float, bool)) or (
204+
isinstance(v, list) and all(isinstance(x, (str, int, float, bool)) for x in v)
205+
):
206+
self.trace_attributes[k] = v
207+
208+
# Initialize tool registry
57209
self.tool_registry = ToolRegistry()
58-
if tools:
210+
211+
if tools is not None:
59212
self.tool_registry.process_tools(tools)
60-
self.tool_registry.initialize_tools()
61-
62-
# Initialize tool executor for concurrent execution
63-
self.tool_executor = ConcurrentToolExecutor()
213+
214+
self.tool_registry.initialize_tools(self.load_tools_from_directory)
215+
216+
# Initialize tool watcher if directory loading is enabled
217+
if self.load_tools_from_directory:
218+
self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry)
219+
220+
# Initialize tool executor
221+
self.tool_executor = tool_executor or ConcurrentToolExecutor()
222+
223+
# Initialize hooks system
224+
self.hooks = HookRegistry()
225+
if hooks:
226+
for hook in hooks:
227+
self.hooks.add_hook(hook)
228+
229+
# Initialize other components
230+
self.event_loop_metrics = EventLoopMetrics()
231+
self.tool_caller = BidirectionalAgent.ToolCaller(self)
64232

65233
# Session management
66234
self._session = None
67235
self._output_queue = asyncio.Queue()
68236

237+
@property
238+
def tool(self) -> ToolCaller:
239+
"""Call tool as a function.
240+
241+
Returns:
242+
Tool caller through which user can invoke tool as a function.
243+
244+
Example:
245+
```
246+
agent = BidirectionalAgent(model=model, tools=[calculator])
247+
agent.tool.calculator(expression="2+2")
248+
```
249+
"""
250+
return self.tool_caller
251+
252+
@property
253+
def tool_names(self) -> list[str]:
254+
"""Get a list of all registered tool names.
255+
256+
Returns:
257+
Names of all tools available to this agent.
258+
"""
259+
all_tools = self.tool_registry.get_all_tools_config()
260+
return list(all_tools.keys())
261+
262+
def _record_tool_execution(
263+
self,
264+
tool: ToolUse,
265+
tool_result: ToolResult,
266+
user_message_override: Optional[str],
267+
) -> None:
268+
"""Record a tool execution in the message history.
269+
270+
Creates a sequence of messages that represent the tool execution:
271+
272+
1. A user message describing the tool call
273+
2. An assistant message with the tool use
274+
3. A user message with the tool result
275+
4. An assistant message acknowledging the tool call
276+
277+
Args:
278+
tool: The tool call information.
279+
tool_result: The result returned by the tool.
280+
user_message_override: Optional custom message to include.
281+
"""
282+
# Filter tool input parameters to only include those defined in tool spec
283+
filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"])
284+
285+
# Create user message describing the tool call
286+
input_parameters = json.dumps(filtered_input, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
287+
288+
user_msg_content = [
289+
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")}
290+
]
291+
292+
# Add override message if provided
293+
if user_message_override:
294+
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})
295+
296+
# Create filtered tool use for message history
297+
filtered_tool: ToolUse = {
298+
"toolUseId": tool["toolUseId"],
299+
"name": tool["name"],
300+
"input": filtered_input,
301+
}
302+
303+
# Create the message sequence
304+
user_msg: Message = {
305+
"role": "user",
306+
"content": user_msg_content,
307+
}
308+
tool_use_msg: Message = {
309+
"role": "assistant",
310+
"content": [{"toolUse": filtered_tool}],
311+
}
312+
tool_result_msg: Message = {
313+
"role": "user",
314+
"content": [{"toolResult": tool_result}],
315+
}
316+
assistant_msg: Message = {
317+
"role": "assistant",
318+
"content": [{"text": f"agent.tool.{tool['name']} was called."}],
319+
}
320+
321+
# Add to message history
322+
self.messages.append(user_msg)
323+
self.messages.append(tool_use_msg)
324+
self.messages.append(tool_result_msg)
325+
self.messages.append(assistant_msg)
326+
327+
logger.debug("Direct tool call recorded in message history: %s", tool["name"])
328+
329+
def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]:
330+
"""Filter input parameters to only include those defined in the tool specification.
331+
332+
Args:
333+
tool_name: Name of the tool to get specification for
334+
input_params: Original input parameters
335+
336+
Returns:
337+
Filtered parameters containing only those defined in tool spec
338+
"""
339+
all_tools_config = self.tool_registry.get_all_tools_config()
340+
tool_spec = all_tools_config.get(tool_name)
341+
342+
if not tool_spec or "inputSchema" not in tool_spec:
343+
return input_params.copy()
344+
345+
properties = tool_spec["inputSchema"]["json"]["properties"]
346+
return {k: v for k, v in input_params.items() if k in properties}
347+
69348
async def start(self) -> None:
70349
"""Start a persistent bidirectional conversation session.
71350

0 commit comments

Comments
 (0)