Skip to content

Commit 79e2992

Browse files
committed
feat(mcp): add experimental agent managed connection support
1 parent 1f25512 commit 79e2992

File tree

23 files changed

+1322
-60
lines changed

23 files changed

+1322
-60
lines changed

src/strands/_async.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Private async execution utilities."""
2+
3+
import asyncio
4+
from concurrent.futures import ThreadPoolExecutor
5+
from typing import Awaitable, Callable, TypeVar
6+
7+
T = TypeVar("T")
8+
9+
10+
def run_async(async_func: Callable[[], Awaitable[T]]) -> T:
11+
"""Run an async function in a separate thread to avoid event loop conflicts.
12+
13+
This utility handles the common pattern of running async code from sync contexts
14+
by using ThreadPoolExecutor to isolate the async execution.
15+
16+
Args:
17+
async_func: A callable that returns an awaitable
18+
19+
Returns:
20+
The result of the async function
21+
"""
22+
23+
async def execute_async() -> T:
24+
return await async_func()
25+
26+
def execute() -> T:
27+
return asyncio.run(execute_async())
28+
29+
with ThreadPoolExecutor() as executor:
30+
future = executor.submit(execute)
31+
return future.result()

src/strands/agent/agent.py

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
1010
"""
1111

12-
import asyncio
1312
import json
1413
import logging
1514
import random
16-
from concurrent.futures import ThreadPoolExecutor
1715
from typing import (
1816
Any,
1917
AsyncGenerator,
@@ -31,7 +29,9 @@
3129
from pydantic import BaseModel
3230

3331
from .. import _identifier
32+
from .._async import run_async
3433
from ..event_loop.event_loop import event_loop_cycle
34+
from ..experimental.tools import ToolProvider
3535
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
3636
from ..hooks import (
3737
AfterInvocationEvent,
@@ -160,12 +160,7 @@ async def acall() -> ToolResult:
160160

161161
return tool_results[0]
162162

163-
def tcall() -> ToolResult:
164-
return asyncio.run(acall())
165-
166-
with ThreadPoolExecutor() as executor:
167-
future = executor.submit(tcall)
168-
tool_result = future.result()
163+
tool_result = run_async(acall)
169164

170165
if record_direct_tool_call is not None:
171166
should_record_direct_tool_call = record_direct_tool_call
@@ -208,7 +203,7 @@ def __init__(
208203
self,
209204
model: Union[Model, str, None] = None,
210205
messages: Optional[Messages] = None,
211-
tools: Optional[list[Union[str, dict[str, str], Any]]] = None,
206+
tools: Optional[list[Union[str, dict[str, str], ToolProvider, Any]]] = None,
212207
system_prompt: Optional[str] = None,
213208
callback_handler: Optional[
214209
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
@@ -240,7 +235,8 @@ def __init__(
240235
- File paths (e.g., "/path/to/tool.py")
241236
- Imported Python modules (e.g., from strands_tools import current_time)
242237
- Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"})
243-
- Functions decorated with `@strands.tool` decorator.
238+
- Functions decorated with `@strands.tool` decorator
239+
- ToolProvider instances for managed tool collections
244240
245241
If provided, only these tools will be available. If None, all tools will be available.
246242
system_prompt: System prompt to guide model behavior.
@@ -333,6 +329,9 @@ def __init__(
333329
else:
334330
self.state = AgentState()
335331

332+
# Track cleanup state
333+
self._cleanup_called = False
334+
336335
self.tool_caller = Agent.ToolCaller(self)
337336

338337
self.hooks = HookRegistry()
@@ -399,13 +398,7 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
399398
- metrics: Performance metrics from the event loop
400399
- state: The final state of the event loop
401400
"""
402-
403-
def execute() -> AgentResult:
404-
return asyncio.run(self.invoke_async(prompt, **kwargs))
405-
406-
with ThreadPoolExecutor() as executor:
407-
future = executor.submit(execute)
408-
return future.result()
401+
return run_async(lambda: self.invoke_async(prompt, **kwargs))
409402

410403
async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
411404
"""Process a natural language prompt through the agent's event loop.
@@ -459,13 +452,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) ->
459452
Raises:
460453
ValueError: If no conversation history or prompt is provided.
461454
"""
462-
463-
def execute() -> T:
464-
return asyncio.run(self.structured_output_async(output_model, prompt))
465-
466-
with ThreadPoolExecutor() as executor:
467-
future = executor.submit(execute)
468-
return future.result()
455+
return run_async(lambda: self.structured_output_async(output_model, prompt))
469456

470457
async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T:
471458
"""This method allows you to get structured output from the agent.
@@ -524,6 +511,69 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
524511
finally:
525512
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
526513

514+
def cleanup(self) -> None:
515+
"""Clean up resources used by the agent.
516+
517+
This method cleans up all tool providers that require explicit cleanup,
518+
such as MCP clients. It should be called when the agent is no longer needed
519+
to ensure proper resource cleanup.
520+
521+
Note: This method uses a "belt and braces" approach with automatic cleanup
522+
through __del__ as a fallback, but explicit cleanup is recommended.
523+
"""
524+
run_async(self.cleanup_async)
525+
526+
async def cleanup_async(self) -> None:
527+
"""Asynchronously clean up resources used by the agent.
528+
529+
This method cleans up all tool providers that require explicit cleanup,
530+
such as MCP clients. It should be called when the agent is no longer needed
531+
to ensure proper resource cleanup.
532+
533+
Note: This method uses a "belt and braces" approach with automatic cleanup
534+
through __del__ as a fallback, but explicit cleanup is recommended.
535+
"""
536+
if self._cleanup_called:
537+
return
538+
539+
logger.debug("agent_id=<%s> | cleaning up agent resources", self.agent_id)
540+
541+
for provider in self.tool_registry.tool_providers:
542+
try:
543+
await provider.cleanup()
544+
logger.debug(
545+
"agent_id=<%s>, provider=<%s> | cleaned up tool provider", self.agent_id, type(provider).__name__
546+
)
547+
except Exception as e:
548+
logger.warning(
549+
"agent_id=<%s>, provider=<%s>, error=<%s> | failed to cleanup tool provider",
550+
self.agent_id,
551+
type(provider).__name__,
552+
e,
553+
)
554+
555+
self._cleanup_called = True
556+
logger.debug("agent_id=<%s> | agent cleanup complete", self.agent_id)
557+
558+
def __del__(self) -> None:
559+
"""Automatic cleanup when agent is garbage collected.
560+
561+
This serves as a fallback cleanup mechanism, but explicit cleanup() is preferred.
562+
"""
563+
try:
564+
if self._cleanup_called or not self.tool_registry.tool_providers:
565+
return
566+
567+
logger.warning(
568+
"agent_id=<%s> | Agent cleanup called via __del__. "
569+
"Consider calling agent.cleanup() explicitly for better resource management.",
570+
self.agent_id,
571+
)
572+
self.cleanup()
573+
except Exception as e:
574+
# Log exceptions during garbage collection cleanup for debugging
575+
logger.debug("agent_id=<%s>, error=<%s> | exception during __del__ cleanup", self.agent_id, e)
576+
527577
async def stream_async(
528578
self,
529579
prompt: AgentInput = None,

src/strands/experimental/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,7 @@
22
33
This module implements experimental features that are subject to change in future revisions without notice.
44
"""
5+
6+
from . import tools
7+
8+
__all__ = ["tools"]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Experimental tools package."""
2+
3+
from .tool_provider import ToolProvider
4+
5+
__all__ = ["ToolProvider"]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Experimental MCP Tool Provider."""
2+
3+
from .mcp_tool_provider import MCPToolProvider, ToolFilters
4+
5+
__all__ = ["MCPToolProvider", "ToolFilters"]
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""MCP Tool Provider implementation."""
2+
3+
import logging
4+
from typing import Callable, Optional, Pattern, Sequence, Union
5+
6+
from typing_extensions import TypedDict
7+
8+
from ....tools.mcp.mcp_agent_tool import MCPAgentTool
9+
from ....tools.mcp.mcp_client import MCPClient
10+
from ....types.exceptions import ToolProviderException
11+
from ....types.tools import AgentTool
12+
from ..tool_provider import ToolProvider
13+
14+
logger = logging.getLogger(__name__)
15+
16+
_ToolFilterCallback = Callable[[AgentTool], bool]
17+
_ToolFilterPattern = Union[str, Pattern[str], _ToolFilterCallback]
18+
19+
20+
class ToolFilters(TypedDict, total=False):
21+
"""Filters for controlling which MCP tools are loaded and available.
22+
23+
Tools are filtered in this order:
24+
1. If 'allowed' is specified, only tools matching these patterns are included
25+
2. Tools matching 'rejected' patterns are then excluded
26+
3. If the result exceeds 'max_tools', it's truncated
27+
"""
28+
29+
allowed: list[_ToolFilterPattern]
30+
rejected: list[_ToolFilterPattern]
31+
max_tools: int
32+
33+
34+
class MCPToolProvider(ToolProvider):
35+
"""Tool provider for MCP clients with managed lifecycle."""
36+
37+
def __init__(
38+
self, *, client: MCPClient, tool_filters: Optional[ToolFilters] = None, disambiguator: Optional[str] = None
39+
) -> None:
40+
"""Initialize with an MCP client.
41+
42+
Args:
43+
client: The MCP client to manage.
44+
tool_filters: Optional filters to apply to tools.
45+
disambiguator: Optional prefix for tool names.
46+
"""
47+
logger.debug(
48+
"tool_filters=<%s>, disambiguator=<%s> | initializing MCPToolProvider", tool_filters, disambiguator
49+
)
50+
self._client = client
51+
self._tool_filters = tool_filters
52+
self._disambiguator = disambiguator
53+
self._tools: Optional[list[MCPAgentTool]] = None # None = not loaded yet, [] = loaded but empty
54+
self._started = False
55+
56+
async def load_tools(self) -> Sequence[AgentTool]:
57+
"""Load and return tools from the MCP client.
58+
59+
Returns:
60+
List of tools from the MCP server.
61+
"""
62+
logger.debug("started=<%s>, cached_tools=<%s> | loading tools", self._started, self._tools is not None)
63+
64+
if not self._started:
65+
try:
66+
logger.debug("starting MCP client")
67+
self._client.start()
68+
self._started = True
69+
logger.debug("MCP client started successfully")
70+
except Exception as e:
71+
logger.error("error=<%s> | failed to start MCP client", e)
72+
raise ToolProviderException(f"Failed to start MCP client: {e}") from e
73+
74+
if self._tools is None:
75+
logger.debug("loading tools from MCP server")
76+
self._tools = []
77+
pagination_token = None
78+
page_count = 0
79+
80+
# Determine max_tools limit for early termination
81+
max_tools_limit = None
82+
if self._tool_filters and "max_tools" in self._tool_filters:
83+
max_tools_limit = self._tool_filters["max_tools"]
84+
logger.debug("max_tools_limit=<%d> | will stop when reached", max_tools_limit)
85+
86+
while True:
87+
logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token)
88+
paginated_tools = self._client.list_tools_sync(pagination_token)
89+
90+
# Process each tool as we get it
91+
for tool in paginated_tools:
92+
# Apply filters
93+
if self._should_include_tool(tool):
94+
# Apply disambiguation if needed
95+
processed_tool = self._apply_disambiguation(tool)
96+
self._tools.append(processed_tool)
97+
98+
# Check if we've reached max_tools limit
99+
if max_tools_limit is not None and len(self._tools) >= max_tools_limit:
100+
logger.debug("max_tools_reached=<%d> | stopping pagination early", len(self._tools))
101+
return self._tools
102+
103+
logger.debug(
104+
"page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page",
105+
page_count,
106+
len(paginated_tools),
107+
len(self._tools),
108+
)
109+
110+
pagination_token = paginated_tools.pagination_token
111+
page_count += 1
112+
113+
if pagination_token is None:
114+
break
115+
116+
logger.debug("final_tools=<%d> | loading complete", len(self._tools))
117+
118+
return self._tools
119+
120+
def _should_include_tool(self, tool: MCPAgentTool) -> bool:
121+
"""Check if a tool should be included based on allowed/rejected filters."""
122+
if not self._tool_filters:
123+
return True
124+
125+
# Apply allowed filter
126+
if "allowed" in self._tool_filters:
127+
if not self._matches_patterns(tool, self._tool_filters["allowed"]):
128+
return False
129+
130+
# Apply rejected filter
131+
if "rejected" in self._tool_filters:
132+
if self._matches_patterns(tool, self._tool_filters["rejected"]):
133+
return False
134+
135+
return True
136+
137+
def _apply_disambiguation(self, tool: MCPAgentTool) -> MCPAgentTool:
138+
"""Apply disambiguation to a single tool if needed."""
139+
if not self._disambiguator:
140+
return tool
141+
142+
# Create new tool with disambiguated agent name but preserve original MCP name
143+
old_name = tool.tool_name
144+
new_agent_name = f"{self._disambiguator}_{tool.mcp_tool.name}"
145+
new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name)
146+
logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name)
147+
return new_tool
148+
149+
def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool:
150+
"""Check if tool matches any of the given patterns."""
151+
for pattern in patterns:
152+
if callable(pattern):
153+
if pattern(tool):
154+
return True
155+
elif hasattr(pattern, "match") and hasattr(pattern, "pattern"):
156+
if pattern.match(tool.tool_name):
157+
return True
158+
elif isinstance(pattern, str):
159+
if pattern == tool.tool_name:
160+
return True
161+
return False
162+
163+
async def cleanup(self) -> None:
164+
"""Clean up the MCP client connection."""
165+
if not self._started:
166+
return
167+
168+
logger.debug("cleaning up MCP client")
169+
try:
170+
logger.debug("stopping MCP client")
171+
self._client.stop(None, None, None)
172+
logger.debug("MCP client stopped successfully")
173+
except Exception as e:
174+
logger.error("error=<%s> | failed to cleanup MCP client", e)
175+
raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e
176+
177+
# Only reset state if cleanup succeeded
178+
self._started = False
179+
self._tools = None
180+
logger.debug("MCP client cleanup complete")

0 commit comments

Comments
 (0)