Skip to content

Commit 3446938

Browse files
feat(mcp): add experimental agent managed connection via ToolProvider (#895)
1 parent 999e654 commit 3446938

File tree

24 files changed

+1925
-72
lines changed

24 files changed

+1925
-72
lines changed

.codecov.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
coverage:
2+
status:
3+
project:
4+
default:
5+
target: 90% # overall coverage threshold
6+
patch:
7+
default:
8+
target: 90% # patch coverage threshold
9+
base: auto
10+
# Only post patch coverage on decreases
11+
only_pulls: true

src/strands/_async.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Private async execution utilities."""
2+
3+
import asyncio
4+
from concurrent.futures import ThreadPoolExecutor
5+
from typing import Awaitable, Callable, TypeVar
6+
7+
T = TypeVar("T")
8+
9+
10+
def run_async(async_func: Callable[[], Awaitable[T]]) -> T:
11+
"""Run an async function in a separate thread to avoid event loop conflicts.
12+
13+
This utility handles the common pattern of running async code from sync contexts
14+
by using ThreadPoolExecutor to isolate the async execution.
15+
16+
Args:
17+
async_func: A callable that returns an awaitable
18+
19+
Returns:
20+
The result of the async function
21+
"""
22+
23+
async def execute_async() -> T:
24+
return await async_func()
25+
26+
def execute() -> T:
27+
return asyncio.run(execute_async())
28+
29+
with ThreadPoolExecutor() as executor:
30+
future = executor.submit(execute)
31+
return future.result()

src/strands/agent/agent.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
1010
"""
1111

12-
import asyncio
1312
import json
1413
import logging
1514
import random
1615
import warnings
17-
from concurrent.futures import ThreadPoolExecutor
1816
from typing import (
17+
TYPE_CHECKING,
1918
Any,
2019
AsyncGenerator,
2120
AsyncIterator,
@@ -32,7 +31,11 @@
3231
from pydantic import BaseModel
3332

3433
from .. import _identifier
34+
from .._async import run_async
3535
from ..event_loop.event_loop import event_loop_cycle
36+
37+
if TYPE_CHECKING:
38+
from ..experimental.tools import ToolProvider
3639
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
3740
from ..hooks import (
3841
AfterInvocationEvent,
@@ -167,12 +170,7 @@ async def acall() -> ToolResult:
167170

168171
return tool_results[0]
169172

170-
def tcall() -> ToolResult:
171-
return asyncio.run(acall())
172-
173-
with ThreadPoolExecutor() as executor:
174-
future = executor.submit(tcall)
175-
tool_result = future.result()
173+
tool_result = run_async(acall)
176174

177175
if record_direct_tool_call is not None:
178176
should_record_direct_tool_call = record_direct_tool_call
@@ -215,7 +213,7 @@ def __init__(
215213
self,
216214
model: Union[Model, str, None] = None,
217215
messages: Optional[Messages] = None,
218-
tools: Optional[list[Union[str, dict[str, str], Any]]] = None,
216+
tools: Optional[list[Union[str, dict[str, str], "ToolProvider", Any]]] = None,
219217
system_prompt: Optional[str] = None,
220218
structured_output_model: Optional[Type[BaseModel]] = None,
221219
callback_handler: Optional[
@@ -248,6 +246,7 @@ def __init__(
248246
- File paths (e.g., "/path/to/tool.py")
249247
- Imported Python modules (e.g., from strands_tools import current_time)
250248
- Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"})
249+
- ToolProvider instances for managed tool collections
251250
- Functions decorated with `@strands.tool` decorator.
252251
253252
If provided, only these tools will be available. If None, all tools will be available.
@@ -423,17 +422,11 @@ def __call__(
423422
- state: The final state of the event loop
424423
- structured_output: Parsed structured output when structured_output_model was specified
425424
"""
426-
427-
def execute() -> AgentResult:
428-
return asyncio.run(
429-
self.invoke_async(
430-
prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs
431-
)
425+
return run_async(
426+
lambda: self.invoke_async(
427+
prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs
432428
)
433-
434-
with ThreadPoolExecutor() as executor:
435-
future = executor.submit(execute)
436-
return future.result()
429+
)
437430

438431
async def invoke_async(
439432
self,
@@ -506,12 +499,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) ->
506499
stacklevel=2,
507500
)
508501

509-
def execute() -> T:
510-
return asyncio.run(self.structured_output_async(output_model, prompt))
511-
512-
with ThreadPoolExecutor() as executor:
513-
future = executor.submit(execute)
514-
return future.result()
502+
return run_async(lambda: self.structured_output_async(output_model, prompt))
515503

516504
async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T:
517505
"""This method allows you to get structured output from the agent.
@@ -529,6 +517,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
529517
530518
Raises:
531519
ValueError: If no conversation history or prompt is provided.
520+
-
532521
"""
533522
if self._interrupt_state.activated:
534523
raise RuntimeError("cannot call structured output during interrupt")
@@ -583,6 +572,25 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
583572
finally:
584573
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
585574

575+
def cleanup(self) -> None:
576+
"""Clean up resources used by the agent.
577+
578+
This method cleans up all tool providers that require explicit cleanup,
579+
such as MCP clients. It should be called when the agent is no longer needed
580+
to ensure proper resource cleanup.
581+
582+
Note: This method uses a "belt and braces" approach with automatic cleanup
583+
through finalizers as a fallback, but explicit cleanup is recommended.
584+
"""
585+
self.tool_registry.cleanup()
586+
587+
def __del__(self) -> None:
588+
"""Clean up resources when agent is garbage collected."""
589+
# __del__ is called even when an exception is thrown in the constructor,
590+
# so there is no guarantee tool_registry was set..
591+
if hasattr(self, "tool_registry"):
592+
self.tool_registry.cleanup()
593+
586594
async def stream_async(
587595
self,
588596
prompt: AgentInput = None,

src/strands/experimental/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
This module implements experimental features that are subject to change in future revisions without notice.
44
"""
55

6+
from . import tools
67
from .agent_config import config_to_agent
78

8-
__all__ = ["config_to_agent"]
9+
__all__ = ["config_to_agent", "tools"]

src/strands/experimental/agent_config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import jsonschema
1919
from jsonschema import ValidationError
2020

21-
from ..agent import Agent
22-
2321
# JSON Schema for agent configuration
2422
AGENT_CONFIG_SCHEMA = {
2523
"$schema": "http://json-schema.org/draft-07/schema#",
@@ -53,7 +51,7 @@
5351
_VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA)
5452

5553

56-
def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Agent:
54+
def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Any:
5755
"""Create an Agent from a configuration file or dictionary.
5856
5957
This function supports tools that can be loaded declaratively (file paths, module names,
@@ -134,5 +132,8 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A
134132
# Override with any additional kwargs provided
135133
agent_kwargs.update(kwargs)
136134

135+
# Import Agent at runtime to avoid circular imports
136+
from ..agent import Agent
137+
137138
# Create and return Agent
138139
return Agent(**agent_kwargs)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Experimental tools package."""
2+
3+
from .tool_provider import ToolProvider
4+
5+
__all__ = ["ToolProvider"]
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Tool provider interface."""
2+
3+
from abc import ABC, abstractmethod
4+
from typing import TYPE_CHECKING, Any, Sequence
5+
6+
if TYPE_CHECKING:
7+
from ...types.tools import AgentTool
8+
9+
10+
class ToolProvider(ABC):
11+
"""Interface for providing tools with lifecycle management.
12+
13+
Provides a way to load a collection of tools and clean them up
14+
when done, with lifecycle managed by the agent.
15+
"""
16+
17+
@abstractmethod
18+
async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]:
19+
"""Load and return the tools in this provider.
20+
21+
Args:
22+
**kwargs: Additional arguments for future compatibility.
23+
24+
Returns:
25+
List of tools that are ready to use.
26+
"""
27+
pass
28+
29+
@abstractmethod
30+
def add_consumer(self, consumer_id: Any, **kwargs: Any) -> None:
31+
"""Add a consumer to this tool provider.
32+
33+
Args:
34+
consumer_id: Unique identifier for the consumer.
35+
**kwargs: Additional arguments for future compatibility.
36+
"""
37+
pass
38+
39+
@abstractmethod
40+
def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None:
41+
"""Remove a consumer from this tool provider.
42+
43+
This method must be idempotent - calling it multiple times with the same ID
44+
should have no additional effect after the first call.
45+
46+
Provider may clean up resources when no consumers remain.
47+
48+
Args:
49+
consumer_id: Unique identifier for the consumer.
50+
**kwargs: Additional arguments for future compatibility.
51+
"""
52+
pass

src/strands/multiagent/base.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
Provides minimal foundation for multi-agent patterns (Swarm, Graph).
44
"""
55

6-
import asyncio
76
import logging
87
import warnings
98
from abc import ABC, abstractmethod
10-
from concurrent.futures import ThreadPoolExecutor
119
from dataclasses import dataclass, field
1210
from enum import Enum
1311
from typing import Any, Union
1412

13+
from .._async import run_async
1514
from ..agent import AgentResult
1615
from ..types.content import ContentBlock
1716
from ..types.event_loop import Metrics, Usage
@@ -199,12 +198,7 @@ def __call__(
199198
invocation_state.update(kwargs)
200199
warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)
201200

202-
def execute() -> MultiAgentResult:
203-
return asyncio.run(self.invoke_async(task, invocation_state))
204-
205-
with ThreadPoolExecutor() as executor:
206-
future = executor.submit(execute)
207-
return future.result()
201+
return run_async(lambda: self.invoke_async(task, invocation_state))
208202

209203
def serialize_state(self) -> dict[str, Any]:
210204
"""Return a JSON-serializable snapshot of the orchestrator state."""

src/strands/multiagent/graph.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
import copy
1919
import logging
2020
import time
21-
from concurrent.futures import ThreadPoolExecutor
2221
from dataclasses import dataclass, field
2322
from typing import Any, Callable, Optional, Tuple
2423

2524
from opentelemetry import trace as trace_api
2625

26+
from .._async import run_async
2727
from ..agent import Agent
2828
from ..agent.state import AgentState
2929
from ..telemetry import get_tracer
@@ -399,12 +399,7 @@ def __call__(
399399
if invocation_state is None:
400400
invocation_state = {}
401401

402-
def execute() -> GraphResult:
403-
return asyncio.run(self.invoke_async(task, invocation_state))
404-
405-
with ThreadPoolExecutor() as executor:
406-
future = executor.submit(execute)
407-
return future.result()
402+
return run_async(lambda: self.invoke_async(task, invocation_state))
408403

409404
async def invoke_async(
410405
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any

src/strands/multiagent/swarm.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
import json
1818
import logging
1919
import time
20-
from concurrent.futures import ThreadPoolExecutor
2120
from dataclasses import dataclass, field
2221
from typing import Any, Callable, Tuple
2322

2423
from opentelemetry import trace as trace_api
2524

26-
from ..agent import Agent, AgentResult
25+
from .._async import run_async
26+
from ..agent import Agent
27+
from ..agent.agent_result import AgentResult
2728
from ..agent.state import AgentState
2829
from ..telemetry import get_tracer
2930
from ..tools.decorator import tool
@@ -254,12 +255,7 @@ def __call__(
254255
if invocation_state is None:
255256
invocation_state = {}
256257

257-
def execute() -> SwarmResult:
258-
return asyncio.run(self.invoke_async(task, invocation_state))
259-
260-
with ThreadPoolExecutor() as executor:
261-
future = executor.submit(execute)
262-
return future.result()
258+
return run_async(lambda: self.invoke_async(task, invocation_state))
263259

264260
async def invoke_async(
265261
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any

0 commit comments

Comments
 (0)