1313"""
1414
1515import asyncio
16+ import json
1617import logging
17- from typing import AsyncIterable
18+ import random
19+ from concurrent .futures import ThreadPoolExecutor
20+ from typing import Any , AsyncIterable , Callable , Mapping , Optional
1821
22+ from .... import _identifier
23+ from ....hooks import HookProvider , HookRegistry
24+ from ....telemetry .metrics import EventLoopMetrics
1925from ....tools .executors import ConcurrentToolExecutor
26+ from ....tools .executors ._executor import ToolExecutor
2027from ....tools .registry import ToolRegistry
21- from ....types .content import Messages
28+ from ....tools .watcher import ToolWatcher
29+ from ....types .content import Message , Messages
30+ from ....types .tools import ToolResult , ToolUse
31+ from ....types .traces import AttributeValue
2232from ..event_loop .bidirectional_event_loop import start_bidirectional_connection , stop_bidirectional_connection
2333from ..models .bidirectional_model import BidirectionalModel
2434from ..types .bidirectional_streaming import AudioInputEvent , BidirectionalStreamEvent
2535
2636
2737logger = logging .getLogger (__name__ )
2838
39+ _DEFAULT_AGENT_NAME = "Strands Agents"
40+ _DEFAULT_AGENT_ID = "default"
41+
2942
3043class BidirectionalAgent :
3144 """Agent for bidirectional streaming conversations.
@@ -34,12 +47,125 @@ class BidirectionalAgent:
3447 sessions. Supports concurrent tool execution and interruption handling.
3548 """
3649
50+ class ToolCaller :
51+ """Call tool as a function for bidirectional agent."""
52+
53+ def __init__ (self , agent : "BidirectionalAgent" ) -> None :
54+ """Initialize tool caller with agent reference."""
55+ # WARNING: Do not add any other member variables or methods as this could result in a name conflict with
56+ # agent tools and thus break their execution.
57+ self ._agent = agent
58+
59+ def __getattr__ (self , name : str ) -> Callable [..., Any ]:
60+ """Call tool as a function.
61+
62+ This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
63+ It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').
64+
65+ Args:
66+ name: The name of the attribute (tool) being accessed.
67+
68+ Returns:
69+ A function that when called will execute the named tool.
70+
71+ Raises:
72+ AttributeError: If no tool with the given name exists or if multiple tools match the given name.
73+ """
74+
75+ def caller (
76+ user_message_override : Optional [str ] = None ,
77+ record_direct_tool_call : Optional [bool ] = None ,
78+ ** kwargs : Any ,
79+ ) -> Any :
80+ """Call a tool directly by name.
81+
82+ Args:
83+ user_message_override: Optional custom message to record instead of default
84+ record_direct_tool_call: Whether to record direct tool calls in message history.
85+ For bidirectional agents, this is always True to maintain conversation history.
86+ **kwargs: Keyword arguments to pass to the tool.
87+
88+ Returns:
89+ The result returned by the tool.
90+
91+ Raises:
92+ AttributeError: If the tool doesn't exist.
93+ """
94+ normalized_name = self ._find_normalized_tool_name (name )
95+
96+ # Create unique tool ID and set up the tool request
97+ tool_id = f"tooluse_{ name } _{ random .randint (100000000 , 999999999 )} "
98+ tool_use : ToolUse = {
99+ "toolUseId" : tool_id ,
100+ "name" : normalized_name ,
101+ "input" : kwargs .copy (),
102+ }
103+ tool_results : list [ToolResult ] = []
104+ invocation_state = kwargs
105+
106+ async def acall () -> ToolResult :
107+ async for event in ToolExecutor ._stream (self ._agent , tool_use , tool_results , invocation_state ):
108+ _ = event
109+
110+ return tool_results [0 ]
111+
112+ def tcall () -> ToolResult :
113+ return asyncio .run (acall ())
114+
115+ with ThreadPoolExecutor () as executor :
116+ future = executor .submit (tcall )
117+ tool_result = future .result ()
118+
119+ # Always record direct tool calls for bidirectional agents to maintain conversation history
120+ # Use agent's record_direct_tool_call setting if not overridden
121+ if record_direct_tool_call is not None :
122+ should_record_direct_tool_call = record_direct_tool_call
123+ else :
124+ should_record_direct_tool_call = self ._agent .record_direct_tool_call
125+
126+ if should_record_direct_tool_call :
127+ # Create a record of this tool execution in the message history
128+ self ._agent ._record_tool_execution (tool_use , tool_result , user_message_override )
129+
130+ return tool_result
131+
132+ return caller
133+
134+ def _find_normalized_tool_name (self , name : str ) -> str :
135+ """Lookup the tool represented by name, replacing characters with underscores as necessary."""
136+ tool_registry = self ._agent .tool_registry .registry
137+
138+ if tool_registry .get (name , None ):
139+ return name
140+
141+ # If the desired name contains underscores, it might be a placeholder for characters that can't be
142+ # represented as python identifiers but are valid as tool names, such as dashes. In that case, find
143+ # all tools that can be represented with the normalized name
144+ if "_" in name :
145+ filtered_tools = [
146+ tool_name for (tool_name , tool ) in tool_registry .items () if tool_name .replace ("-" , "_" ) == name
147+ ]
148+
149+ # The registry itself defends against similar names, so we can just take the first match
150+ if filtered_tools :
151+ return filtered_tools [0 ]
152+
153+ raise AttributeError (f"Tool '{ name } ' not found" )
154+
37155 def __init__ (
38156 self ,
39157 model : BidirectionalModel ,
40158 tools : list | None = None ,
41159 system_prompt : str | None = None ,
42160 messages : Messages | None = None ,
161+ record_direct_tool_call : bool = True ,
162+ load_tools_from_directory : bool = False ,
163+ agent_id : Optional [str ] = None ,
164+ name : Optional [str ] = None ,
165+ tool_executor : Optional [ToolExecutor ] = None ,
166+ hooks : Optional [list [HookProvider ]] = None ,
167+ trace_attributes : Optional [Mapping [str , AttributeValue ]] = None ,
168+ description : Optional [str ] = None ,
43169 ):
44170 """Initialize bidirectional agent with required model and optional configuration.
45171
@@ -48,24 +174,177 @@ def __init__(
48174 tools: Optional list of tools available to the model.
49175 system_prompt: Optional system prompt for conversations.
50176 messages: Optional conversation history to initialize with.
177+ record_direct_tool_call: Whether to record direct tool calls in message history.
178+ load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
179+ agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios.
180+ name: Name of the Agent.
181+ tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.).
182+ hooks: Hooks to be added to the agent hook registry.
183+ trace_attributes: Custom trace attributes to apply to the agent's trace span.
184+ description: Description of what the Agent does.
51185 """
52186 self .model = model
53187 self .system_prompt = system_prompt
54188 self .messages = messages or []
55-
56- # Initialize tool registry using existing Strands infrastructure
189+
190+ # Agent identification
191+ self .agent_id = _identifier .validate (agent_id or _DEFAULT_AGENT_ID , _identifier .Identifier .AGENT )
192+ self .name = name or _DEFAULT_AGENT_NAME
193+ self .description = description
194+
195+ # Tool execution configuration
196+ self .record_direct_tool_call = record_direct_tool_call
197+ self .load_tools_from_directory = load_tools_from_directory
198+
199+ # Process trace attributes to ensure they're of compatible types
200+ self .trace_attributes : dict [str , AttributeValue ] = {}
201+ if trace_attributes :
202+ for k , v in trace_attributes .items ():
203+ if isinstance (v , (str , int , float , bool )) or (
204+ isinstance (v , list ) and all (isinstance (x , (str , int , float , bool )) for x in v )
205+ ):
206+ self .trace_attributes [k ] = v
207+
208+ # Initialize tool registry
57209 self .tool_registry = ToolRegistry ()
58- if tools :
210+
211+ if tools is not None :
59212 self .tool_registry .process_tools (tools )
60- self .tool_registry .initialize_tools ()
61-
62- # Initialize tool executor for concurrent execution
63- self .tool_executor = ConcurrentToolExecutor ()
213+
214+ self .tool_registry .initialize_tools (self .load_tools_from_directory )
215+
216+ # Initialize tool watcher if directory loading is enabled
217+ if self .load_tools_from_directory :
218+ self .tool_watcher = ToolWatcher (tool_registry = self .tool_registry )
219+
220+ # Initialize tool executor
221+ self .tool_executor = tool_executor or ConcurrentToolExecutor ()
222+
223+ # Initialize hooks system
224+ self .hooks = HookRegistry ()
225+ if hooks :
226+ for hook in hooks :
227+ self .hooks .add_hook (hook )
228+
229+ # Initialize other components
230+ self .event_loop_metrics = EventLoopMetrics ()
231+ self .tool_caller = BidirectionalAgent .ToolCaller (self )
64232
65233 # Session management
66234 self ._session = None
67235 self ._output_queue = asyncio .Queue ()
68236
237+ @property
238+ def tool (self ) -> ToolCaller :
239+ """Call tool as a function.
240+
241+ Returns:
242+ Tool caller through which user can invoke tool as a function.
243+
244+ Example:
245+ ```
246+ agent = BidirectionalAgent(model=model, tools=[calculator])
247+ agent.tool.calculator(expression="2+2")
248+ ```
249+ """
250+ return self .tool_caller
251+
252+ @property
253+ def tool_names (self ) -> list [str ]:
254+ """Get a list of all registered tool names.
255+
256+ Returns:
257+ Names of all tools available to this agent.
258+ """
259+ all_tools = self .tool_registry .get_all_tools_config ()
260+ return list (all_tools .keys ())
261+
262+ def _record_tool_execution (
263+ self ,
264+ tool : ToolUse ,
265+ tool_result : ToolResult ,
266+ user_message_override : Optional [str ],
267+ ) -> None :
268+ """Record a tool execution in the message history.
269+
270+ Creates a sequence of messages that represent the tool execution:
271+
272+ 1. A user message describing the tool call
273+ 2. An assistant message with the tool use
274+ 3. A user message with the tool result
275+ 4. An assistant message acknowledging the tool call
276+
277+ Args:
278+ tool: The tool call information.
279+ tool_result: The result returned by the tool.
280+ user_message_override: Optional custom message to include.
281+ """
282+ # Filter tool input parameters to only include those defined in tool spec
283+ filtered_input = self ._filter_tool_parameters_for_recording (tool ["name" ], tool ["input" ])
284+
285+ # Create user message describing the tool call
286+ input_parameters = json .dumps (filtered_input , default = lambda o : f"<<non-serializable: { type (o ).__qualname__ } >>" )
287+
288+ user_msg_content = [
289+ {"text" : (f"agent.tool.{ tool ['name' ]} direct tool call.\n Input parameters: { input_parameters } \n " )}
290+ ]
291+
292+ # Add override message if provided
293+ if user_message_override :
294+ user_msg_content .insert (0 , {"text" : f"{ user_message_override } \n " })
295+
296+ # Create filtered tool use for message history
297+ filtered_tool : ToolUse = {
298+ "toolUseId" : tool ["toolUseId" ],
299+ "name" : tool ["name" ],
300+ "input" : filtered_input ,
301+ }
302+
303+ # Create the message sequence
304+ user_msg : Message = {
305+ "role" : "user" ,
306+ "content" : user_msg_content ,
307+ }
308+ tool_use_msg : Message = {
309+ "role" : "assistant" ,
310+ "content" : [{"toolUse" : filtered_tool }],
311+ }
312+ tool_result_msg : Message = {
313+ "role" : "user" ,
314+ "content" : [{"toolResult" : tool_result }],
315+ }
316+ assistant_msg : Message = {
317+ "role" : "assistant" ,
318+ "content" : [{"text" : f"agent.tool.{ tool ['name' ]} was called." }],
319+ }
320+
321+ # Add to message history
322+ self .messages .append (user_msg )
323+ self .messages .append (tool_use_msg )
324+ self .messages .append (tool_result_msg )
325+ self .messages .append (assistant_msg )
326+
327+ logger .debug ("Direct tool call recorded in message history: %s" , tool ["name" ])
328+
329+ def _filter_tool_parameters_for_recording (self , tool_name : str , input_params : dict [str , Any ]) -> dict [str , Any ]:
330+ """Filter input parameters to only include those defined in the tool specification.
331+
332+ Args:
333+ tool_name: Name of the tool to get specification for
334+ input_params: Original input parameters
335+
336+ Returns:
337+ Filtered parameters containing only those defined in tool spec
338+ """
339+ all_tools_config = self .tool_registry .get_all_tools_config ()
340+ tool_spec = all_tools_config .get (tool_name )
341+
342+ if not tool_spec or "inputSchema" not in tool_spec :
343+ return input_params .copy ()
344+
345+ properties = tool_spec ["inputSchema" ]["json" ]["properties" ]
346+ return {k : v for k , v in input_params .items () if k in properties }
347+
69348 async def start (self ) -> None :
70349 """Start a persistent bidirectional conversation session.
71350
0 commit comments