From cce1dcd859444ae7ebded186c84d2cf097fd337f Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 17 Sep 2025 21:18:02 -0400 Subject: [PATCH 01/35] feat(mcp): add experimental agent managed connection support --- src/strands/_async.py | 31 ++ src/strands/agent/agent.py | 98 +++-- src/strands/experimental/__init__.py | 3 +- src/strands/experimental/tools/__init__.py | 5 + .../experimental/tools/mcp/__init__.py | 5 + .../tools/mcp/mcp_tool_provider.py | 180 +++++++++ .../experimental/tools/tool_provider.py | 32 ++ src/strands/multiagent/base.py | 10 +- src/strands/multiagent/graph.py | 9 +- src/strands/multiagent/swarm.py | 12 +- src/strands/tools/mcp/mcp_agent_tool.py | 13 +- src/strands/tools/registry.py | 13 + src/strands/types/exceptions.py | 6 + tests/strands/agent/test_agent.py | 178 ++++++++- tests/strands/experimental/tools/__init__.py | 0 .../experimental/tools/mcp/__init__.py | 0 .../tools/mcp/test_mcp_tool_provider.py | 375 ++++++++++++++++++ tests/strands/test_async.py | 25 ++ .../tools/test_registry_tool_provider.py | 202 ++++++++++ tests_integ/test_mcp_tool_provider.py | 171 ++++++++ 20 files changed, 1314 insertions(+), 54 deletions(-) create mode 100644 src/strands/_async.py create mode 100644 src/strands/experimental/tools/__init__.py create mode 100644 src/strands/experimental/tools/mcp/__init__.py create mode 100644 src/strands/experimental/tools/mcp/mcp_tool_provider.py create mode 100644 src/strands/experimental/tools/tool_provider.py create mode 100644 tests/strands/experimental/tools/__init__.py create mode 100644 tests/strands/experimental/tools/mcp/__init__.py create mode 100644 tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py create mode 100644 tests/strands/test_async.py create mode 100644 tests/strands/tools/test_registry_tool_provider.py create mode 100644 tests_integ/test_mcp_tool_provider.py diff --git a/src/strands/_async.py b/src/strands/_async.py new file mode 100644 index 000000000..976487c37 --- /dev/null +++ b/src/strands/_async.py @@ -0,0 +1,31 @@ +"""Private async execution utilities.""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from typing import Awaitable, Callable, TypeVar + +T = TypeVar("T") + + +def run_async(async_func: Callable[[], Awaitable[T]]) -> T: + """Run an async function in a separate thread to avoid event loop conflicts. + + This utility handles the common pattern of running async code from sync contexts + by using ThreadPoolExecutor to isolate the async execution. + + Args: + async_func: A callable that returns an awaitable + + Returns: + The result of the async function + """ + + async def execute_async() -> T: + return await async_func() + + def execute() -> T: + return asyncio.run(execute_async()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(execute) + return future.result() diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f963f14e7..1eab8b639 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,12 +9,10 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ -import asyncio import json import logging import random import warnings -from concurrent.futures import ThreadPoolExecutor from typing import ( Any, AsyncGenerator, @@ -32,7 +30,9 @@ from pydantic import BaseModel from .. import _identifier +from .._async import run_async from ..event_loop.event_loop import event_loop_cycle +from ..experimental.tools import ToolProvider from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -166,12 +166,7 @@ async def acall() -> ToolResult: return tool_results[0] - def tcall() -> ToolResult: - return asyncio.run(acall()) - - with ThreadPoolExecutor() as executor: - future = executor.submit(tcall) - tool_result = future.result() + tool_result = run_async(acall) if record_direct_tool_call is not None: should_record_direct_tool_call = record_direct_tool_call @@ -214,7 +209,7 @@ def __init__( self, model: Union[Model, str, None] = None, messages: Optional[Messages] = None, - tools: Optional[list[Union[str, dict[str, str], Any]]] = None, + tools: Optional[list[Union[str, dict[str, str], ToolProvider, Any]]] = None, system_prompt: Optional[str] = None, callback_handler: Optional[ Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] @@ -246,7 +241,8 @@ def __init__( - File paths (e.g., "/path/to/tool.py") - Imported Python modules (e.g., from strands_tools import current_time) - Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"}) - - Functions decorated with `@strands.tool` decorator. + - Functions decorated with `@strands.tool` decorator + - ToolProvider instances for managed tool collections If provided, only these tools will be available. If None, all tools will be available. system_prompt: System prompt to guide model behavior. @@ -339,6 +335,9 @@ def __init__( else: self.state = AgentState() + # Track cleanup state + self._cleanup_called = False + self.tool_caller = Agent.ToolCaller(self) self.hooks = HookRegistry() @@ -410,13 +409,7 @@ def __call__( - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - - def execute() -> AgentResult: - return asyncio.run(self.invoke_async(prompt, invocation_state=invocation_state, **kwargs)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(prompt, invocation_state=invocation_state, **kwargs)) async def invoke_async( self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -473,13 +466,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> Raises: ValueError: If no conversation history or prompt is provided. """ - - def execute() -> T: - return asyncio.run(self.structured_output_async(output_model, prompt)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.structured_output_async(output_model, prompt)) async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. @@ -544,6 +531,69 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu finally: self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + def cleanup(self) -> None: + """Clean up resources used by the agent. + + This method cleans up all tool providers that require explicit cleanup, + such as MCP clients. It should be called when the agent is no longer needed + to ensure proper resource cleanup. + + Note: This method uses a "belt and braces" approach with automatic cleanup + through __del__ as a fallback, but explicit cleanup is recommended. + """ + run_async(self.cleanup_async) + + async def cleanup_async(self) -> None: + """Asynchronously clean up resources used by the agent. + + This method cleans up all tool providers that require explicit cleanup, + such as MCP clients. It should be called when the agent is no longer needed + to ensure proper resource cleanup. + + Note: This method uses a "belt and braces" approach with automatic cleanup + through __del__ as a fallback, but explicit cleanup is recommended. + """ + if self._cleanup_called: + return + + logger.debug("agent_id=<%s> | cleaning up agent resources", self.agent_id) + + for provider in self.tool_registry.tool_providers: + try: + await provider.cleanup() + logger.debug( + "agent_id=<%s>, provider=<%s> | cleaned up tool provider", self.agent_id, type(provider).__name__ + ) + except Exception as e: + logger.warning( + "agent_id=<%s>, provider=<%s>, error=<%s> | failed to cleanup tool provider", + self.agent_id, + type(provider).__name__, + e, + ) + + self._cleanup_called = True + logger.debug("agent_id=<%s> | agent cleanup complete", self.agent_id) + + def __del__(self) -> None: + """Automatic cleanup when agent is garbage collected. + + This serves as a fallback cleanup mechanism, but explicit cleanup() is preferred. + """ + try: + if self._cleanup_called or not self.tool_registry.tool_providers: + return + + logger.warning( + "agent_id=<%s> | Agent cleanup called via __del__. " + "Consider calling agent.cleanup() explicitly for better resource management.", + self.agent_id, + ) + self.cleanup() + except Exception as e: + # Log exceptions during garbage collection cleanup for debugging + logger.debug("agent_id=<%s>, error=<%s> | exception during __del__ cleanup", self.agent_id, e) + async def stream_async( self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> AsyncIterator[Any]: diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index 86618c153..188c80c69 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -3,6 +3,7 @@ This module implements experimental features that are subject to change in future revisions without notice. """ +from . import tools from .agent_config import config_to_agent -__all__ = ["config_to_agent"] +__all__ = ["config_to_agent", "tools"] diff --git a/src/strands/experimental/tools/__init__.py b/src/strands/experimental/tools/__init__.py new file mode 100644 index 000000000..ad693f8ac --- /dev/null +++ b/src/strands/experimental/tools/__init__.py @@ -0,0 +1,5 @@ +"""Experimental tools package.""" + +from .tool_provider import ToolProvider + +__all__ = ["ToolProvider"] diff --git a/src/strands/experimental/tools/mcp/__init__.py b/src/strands/experimental/tools/mcp/__init__.py new file mode 100644 index 000000000..ee1ccc542 --- /dev/null +++ b/src/strands/experimental/tools/mcp/__init__.py @@ -0,0 +1,5 @@ +"""Experimental MCP Tool Provider.""" + +from .mcp_tool_provider import MCPToolProvider, ToolFilters + +__all__ = ["MCPToolProvider", "ToolFilters"] diff --git a/src/strands/experimental/tools/mcp/mcp_tool_provider.py b/src/strands/experimental/tools/mcp/mcp_tool_provider.py new file mode 100644 index 000000000..610d642bc --- /dev/null +++ b/src/strands/experimental/tools/mcp/mcp_tool_provider.py @@ -0,0 +1,180 @@ +"""MCP Tool Provider implementation.""" + +import logging +from typing import Callable, Optional, Pattern, Sequence, Union + +from typing_extensions import TypedDict + +from ....tools.mcp.mcp_agent_tool import MCPAgentTool +from ....tools.mcp.mcp_client import MCPClient +from ....types.exceptions import ToolProviderException +from ....types.tools import AgentTool +from ..tool_provider import ToolProvider + +logger = logging.getLogger(__name__) + +_ToolFilterCallback = Callable[[AgentTool], bool] +_ToolFilterPattern = Union[str, Pattern[str], _ToolFilterCallback] + + +class ToolFilters(TypedDict, total=False): + """Filters for controlling which MCP tools are loaded and available. + + Tools are filtered in this order: + 1. If 'allowed' is specified, only tools matching these patterns are included + 2. Tools matching 'rejected' patterns are then excluded + 3. If the result exceeds 'max_tools', it's truncated + """ + + allowed: list[_ToolFilterPattern] + rejected: list[_ToolFilterPattern] + max_tools: int + + +class MCPToolProvider(ToolProvider): + """Tool provider for MCP clients with managed lifecycle.""" + + def __init__( + self, *, client: MCPClient, tool_filters: Optional[ToolFilters] = None, disambiguator: Optional[str] = None + ) -> None: + """Initialize with an MCP client. + + Args: + client: The MCP client to manage. + tool_filters: Optional filters to apply to tools. + disambiguator: Optional prefix for tool names. + """ + logger.debug( + "tool_filters=<%s>, disambiguator=<%s> | initializing MCPToolProvider", tool_filters, disambiguator + ) + self._client = client + self._tool_filters = tool_filters + self._disambiguator = disambiguator + self._tools: Optional[list[MCPAgentTool]] = None # None = not loaded yet, [] = loaded but empty + self._started = False + + async def load_tools(self) -> Sequence[AgentTool]: + """Load and return tools from the MCP client. + + Returns: + List of tools from the MCP server. + """ + logger.debug("started=<%s>, cached_tools=<%s> | loading tools", self._started, self._tools is not None) + + if not self._started: + try: + logger.debug("starting MCP client") + self._client.start() + self._started = True + logger.debug("MCP client started successfully") + except Exception as e: + logger.error("error=<%s> | failed to start MCP client", e) + raise ToolProviderException(f"Failed to start MCP client: {e}") from e + + if self._tools is None: + logger.debug("loading tools from MCP server") + self._tools = [] + pagination_token = None + page_count = 0 + + # Determine max_tools limit for early termination + max_tools_limit = None + if self._tool_filters and "max_tools" in self._tool_filters: + max_tools_limit = self._tool_filters["max_tools"] + logger.debug("max_tools_limit=<%d> | will stop when reached", max_tools_limit) + + while True: + logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) + paginated_tools = self._client.list_tools_sync(pagination_token) + + # Process each tool as we get it + for tool in paginated_tools: + # Apply filters + if self._should_include_tool(tool): + # Apply disambiguation if needed + processed_tool = self._apply_disambiguation(tool) + self._tools.append(processed_tool) + + # Check if we've reached max_tools limit + if max_tools_limit is not None and len(self._tools) >= max_tools_limit: + logger.debug("max_tools_reached=<%d> | stopping pagination early", len(self._tools)) + return self._tools + + logger.debug( + "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", + page_count, + len(paginated_tools), + len(self._tools), + ) + + pagination_token = paginated_tools.pagination_token + page_count += 1 + + if pagination_token is None: + break + + logger.debug("final_tools=<%d> | loading complete", len(self._tools)) + + return self._tools + + def _should_include_tool(self, tool: MCPAgentTool) -> bool: + """Check if a tool should be included based on allowed/rejected filters.""" + if not self._tool_filters: + return True + + # Apply allowed filter + if "allowed" in self._tool_filters: + if not self._matches_patterns(tool, self._tool_filters["allowed"]): + return False + + # Apply rejected filter + if "rejected" in self._tool_filters: + if self._matches_patterns(tool, self._tool_filters["rejected"]): + return False + + return True + + def _apply_disambiguation(self, tool: MCPAgentTool) -> MCPAgentTool: + """Apply disambiguation to a single tool if needed.""" + if not self._disambiguator: + return tool + + # Create new tool with disambiguated agent name but preserve original MCP name + old_name = tool.tool_name + new_agent_name = f"{self._disambiguator}_{tool.mcp_tool.name}" + new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name) + logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) + return new_tool + + def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool: + """Check if tool matches any of the given patterns.""" + for pattern in patterns: + if callable(pattern): + if pattern(tool): + return True + elif hasattr(pattern, "match") and hasattr(pattern, "pattern"): + if pattern.match(tool.tool_name): + return True + elif isinstance(pattern, str): + if pattern == tool.tool_name: + return True + return False + + async def cleanup(self) -> None: + """Clean up the MCP client connection.""" + if not self._started: + return + + logger.debug("cleaning up MCP client") + try: + logger.debug("stopping MCP client") + self._client.stop(None, None, None) + logger.debug("MCP client stopped successfully") + except Exception as e: + logger.error("error=<%s> | failed to cleanup MCP client", e) + raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e + + # Only reset state if cleanup succeeded + self._started = False + self._tools = None + logger.debug("MCP client cleanup complete") diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py new file mode 100644 index 000000000..0e8d54dfc --- /dev/null +++ b/src/strands/experimental/tools/tool_provider.py @@ -0,0 +1,32 @@ +"""Tool provider interface.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Sequence + +if TYPE_CHECKING: + from ...types.tools import AgentTool + + +class ToolProvider(ABC): + """Interface for providing tools with lifecycle management. + + Provides a way to load a collection of tools and clean them up + when done, with lifecycle managed by the agent. + """ + + @abstractmethod + async def load_tools(self) -> Sequence["AgentTool"]: + """Load and return the tools in this provider. + + Returns: + List of tools that are ready to use. + """ + pass + + @abstractmethod + async def cleanup(self) -> None: + """Clean up resources used by the tools in this provider. + + Should be called when the tools are no longer needed. + """ + pass diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 0dbd85d81..903e41522 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -3,14 +3,13 @@ Provides minimal foundation for multi-agent patterns (Swarm, Graph). """ -import asyncio import warnings from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum from typing import Any, Union +from .._async import run_async from ..agent import AgentResult from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage @@ -116,9 +115,4 @@ def __call__( invocation_state.update(kwargs) warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) - def execute() -> MultiAgentResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 1dbbfc3af..0aaa6c7a3 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -18,12 +18,12 @@ import copy import logging import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Callable, Optional, Tuple from opentelemetry import trace as trace_api +from .._async import run_async from ..agent import Agent from ..agent.state import AgentState from ..telemetry import get_tracer @@ -399,12 +399,7 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> GraphResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 7542b1b85..3d9dc00c8 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -17,13 +17,14 @@ import json import logging import time -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from typing import Any, Callable, Tuple from opentelemetry import trace as trace_api -from ..agent import Agent, AgentResult +from .._async import run_async +from ..agent import Agent +from ..agent.agent_result import AgentResult from ..agent.state import AgentState from ..telemetry import get_tracer from ..tools.decorator import tool @@ -254,12 +255,7 @@ def __call__( if invocation_state is None: invocation_state = {} - def execute() -> SwarmResult: - return asyncio.run(self.invoke_async(task, invocation_state)) - - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) - return future.result() + return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index acc48443c..91ec6216a 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -28,26 +28,29 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient") -> None: + def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", agent_facing_tool_name: str | None = None) -> None: """Initialize a new MCPAgentTool instance. Args: mcp_tool: The MCP tool to adapt mcp_client: The MCP server connection to use for tool invocation + agent_facing_tool_name: Optional name to use for the agent tool (for disambiguation) + If None, uses the original MCP tool name """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client + self._agent_tool_name = agent_facing_tool_name or mcp_tool.name @property def tool_name(self) -> str: """Get the name of the tool. Returns: - str: The name of the MCP tool + str: The agent-facing name of the tool (may be disambiguated) """ - return self.mcp_tool.name + return self._agent_tool_name @property def tool_spec(self) -> ToolSpec: @@ -63,7 +66,7 @@ def tool_spec(self) -> ToolSpec: spec: ToolSpec = { "inputSchema": {"json": self.mcp_tool.inputSchema}, - "name": self.mcp_tool.name, + "name": self.tool_name, # Use agent-facing name in spec "description": description, } @@ -100,7 +103,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw result = await self.mcp_client.call_tool_async( tool_use_id=tool_use["toolUseId"], - name=self.tool_name, + name=self.mcp_tool.name, # Use original MCP name for server communication arguments=tool_use["input"], ) yield ToolResultEvent(result) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 3631c9dee..ea71b09a0 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -18,6 +18,8 @@ from strands.tools.decorator import DecoratedFunctionTool +from .._async import run_async +from ..experimental.tools import ToolProvider from ..types.tools import AgentTool, ToolSpec from .loader import load_tool_from_string, load_tools_from_module from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec @@ -36,6 +38,7 @@ def __init__(self) -> None: self.registry: Dict[str, AgentTool] = {} self.dynamic_tools: Dict[str, AgentTool] = {} self.tool_config: Optional[Dict[str, Any]] = None + self.tool_providers: List[ToolProvider] = [] def process_tools(self, tools: List[Any]) -> List[str]: """Process tools list. @@ -118,6 +121,16 @@ def add_tool(tool: Any) -> None: elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): for t in tool: add_tool(t) + + # Case 5: ToolProvider + elif isinstance(tool, ToolProvider): + self.tool_providers.append(tool) + + provider_tools = run_async(tool.load_tools) + + for provider_tool in provider_tools: + self.register_tool(provider_tool) + tool_names.append(provider_tool.tool_name) else: logger.warning("tool=<%s> | unrecognized tool specification", tool) diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 90f2b8d7f..adff7add7 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -75,3 +75,9 @@ class SessionException(Exception): """Exception raised when session operations fail.""" pass + + +class ToolProviderException(Exception): + """Exception raised when a tool provider fails to load or cleanup tools.""" + + pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index b58e5f3fd..c9d0cf221 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -890,8 +890,184 @@ def test_agent_tool_names(tools, agent): assert actual == expected +def test_agent_cleanup(agent): + """Test that agent cleanup method works correctly.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.cleanup = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: + agent.cleanup() + + # Verify run_async was called once (for cleanup_async) + mock_run_async.assert_called_once() + # Get the function that was passed to run_async and verify it's cleanup_async + called_func = mock_run_async.call_args[0][0] + assert called_func == agent.cleanup_async + + +@pytest.mark.asyncio +async def test_agent_cleanup_async(agent): + """Test that agent cleanup_async method works correctly.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.cleanup = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + await agent.cleanup_async() + + # Verify provider cleanup was called + mock_provider.cleanup.assert_called_once() + # Verify cleanup was marked as called + assert agent._cleanup_called is True + + +@pytest.mark.asyncio +async def test_agent_cleanup_async_handles_exceptions(agent): + """Test that agent cleanup_async handles exceptions gracefully.""" + # Create mock tool providers, one that raises an exception + mock_provider1 = unittest.mock.MagicMock() + mock_provider1.cleanup = unittest.mock.AsyncMock() + mock_provider2 = unittest.mock.MagicMock() + mock_provider2.cleanup = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + + # Add providers to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] + + # Should not raise exception despite provider2 failing + await agent.cleanup_async() + + # Verify both providers were attempted + mock_provider1.cleanup.assert_called_once() + mock_provider2.cleanup.assert_called_once() + # Verify cleanup was marked as called + assert agent._cleanup_called is True + + +@pytest.mark.asyncio +async def test_agent_cleanup_async_idempotent(agent): + """Test that calling cleanup_async multiple times is safe.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.cleanup = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + # Call cleanup_async twice + await agent.cleanup_async() + await agent.cleanup_async() + + # Verify provider cleanup was only called once due to idempotency + mock_provider.cleanup.assert_called_once() + + +@pytest.mark.asyncio +async def test_agent_cleanup_async_with_no_providers(agent): + """Test that agent cleanup_async works when there are no tool providers.""" + # Ensure no providers + agent.tool_registry.tool_providers = [] + + # Should not raise any exceptions + await agent.cleanup_async() + + # Verify cleanup was marked as called + assert agent._cleanup_called is True + + def test_agent__del__(agent): - del agent + """Test that agent destructor calls cleanup.""" + # Add a mock tool provider so cleanup will be called + mock_provider = unittest.mock.MagicMock() + agent.tool_registry.tool_providers = [mock_provider] + + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() + mock_cleanup.assert_called_once() + + +def test_agent__del__handles_cleanup_exception(agent): + """Test that agent destructor handles cleanup exceptions.""" + with unittest.mock.patch.object(agent, "cleanup", side_effect=Exception("Cleanup failed")): + # Should not raise exception + agent.__del__() + + +def test_agent_cleanup_idempotent(agent): + """Test that calling cleanup multiple times is safe.""" + # Create mock tool provider + mock_provider = unittest.mock.MagicMock() + mock_provider.cleanup = unittest.mock.AsyncMock() + + # Add provider to agent's tool registry + agent.tool_registry.tool_providers = [mock_provider] + + # Call cleanup twice + agent.cleanup() + agent.cleanup() + + # Verify provider cleanup was only called once due to idempotency + mock_provider.cleanup.assert_called_once() + + +def test_agent__del__emits_warning_for_automatic_cleanup(agent): + """Test that __del__ emits warning when cleanup wasn't called manually.""" + # Add a mock tool provider so cleanup will be called + mock_provider = unittest.mock.MagicMock() + agent.tool_registry.tool_providers = [mock_provider] + + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() + + # Verify warning was logged + mock_logger.warning.assert_called_once() + warning_call = mock_logger.warning.call_args[0] + assert "Agent cleanup called via __del__" in warning_call[0] + # Verify cleanup was called + mock_cleanup.assert_called_once() + + +def test_agent__del__no_warning_after_manual_cleanup(): + """Test that __del__ doesn't emit warning if cleanup was called manually.""" + # Create a fresh agent for this test + from strands import Agent + + agent = Agent() + + # Call cleanup manually first + with unittest.mock.patch.object(agent, "cleanup_async"): + agent.cleanup() + + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + agent.__del__() + + # Verify no warning was logged + mock_logger.warning.assert_not_called() + + +def test_agent__del__no_warning_when_no_tool_providers(): + """Test that __del__ doesn't emit warning when there are no tool providers.""" + # Create a fresh agent for this test + from strands import Agent + + agent = Agent() + + # Ensure no tool providers + agent.tool_registry.tool_providers = [] + + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() + + # Verify no warning was logged and cleanup wasn't called + mock_logger.warning.assert_not_called() + mock_cleanup.assert_not_called() def test_agent_init_with_no_model_or_model_id(): diff --git a/tests/strands/experimental/tools/__init__.py b/tests/strands/experimental/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/tools/mcp/__init__.py b/tests/strands/experimental/tools/mcp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py b/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py new file mode 100644 index 000000000..ae28d51c8 --- /dev/null +++ b/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py @@ -0,0 +1,375 @@ +"""Unit tests for MCPToolProvider.""" + +import re +from unittest.mock import MagicMock, patch + +import pytest + +from strands.experimental.tools.mcp import MCPToolProvider, ToolFilters +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool +from strands.types import PaginatedList +from strands.types.exceptions import ToolProviderException + + +@pytest.fixture +def mock_mcp_client(): + """Create a mock MCP client.""" + client = MagicMock(spec=MCPClient) + client.start = MagicMock() + client.stop = MagicMock() + client.list_tools_sync = MagicMock() + return client + + +@pytest.fixture +def mock_mcp_tool(): + """Create a mock MCP tool.""" + tool = MagicMock() + tool.name = "test_tool" + return tool + + +@pytest.fixture +def mock_agent_tool(mock_mcp_tool, mock_mcp_client): + """Create a mock MCPAgentTool.""" + agent_tool = MagicMock(spec=MCPAgentTool) + agent_tool.tool_name = "test_tool" + agent_tool.mcp_tool = mock_mcp_tool + agent_tool.mcp_client = mock_mcp_client + return agent_tool + + +def create_mock_tool(name: str) -> MagicMock: + """Helper to create mock tools with specific names.""" + tool = MagicMock(spec=MCPAgentTool) + tool.tool_name = name + tool.mcp_tool = MagicMock() + tool.mcp_tool.name = name + return tool + + +def test_init_with_client_only(mock_mcp_client): + """Test initialization with only client.""" + provider = MCPToolProvider(client=mock_mcp_client) + + assert provider._client is mock_mcp_client + assert provider._tool_filters is None + assert provider._disambiguator is None + assert provider._tools is None + assert provider._started is False + + +def test_init_with_all_parameters(mock_mcp_client): + """Test initialization with all parameters.""" + filters = {"allowed": ["tool1"], "max_tools": 5} + disambiguator = "test_prefix" + + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters, disambiguator=disambiguator) + + assert provider._client is mock_mcp_client + assert provider._tool_filters == filters + assert provider._disambiguator == disambiguator + assert provider._tools is None + assert provider._started is False + + +@pytest.mark.asyncio +async def test_load_tools_starts_client_when_not_started(mock_mcp_client, mock_agent_tool): + """Test that load_tools starts the client when not already started.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) + + provider = MCPToolProvider(client=mock_mcp_client) + + tools = await provider.load_tools() + + mock_mcp_client.start.assert_called_once() + assert provider._started is True + assert len(tools) == 1 + assert tools[0] is mock_agent_tool + + +@pytest.mark.asyncio +async def test_load_tools_does_not_start_client_when_already_started(mock_mcp_client, mock_agent_tool): + """Test that load_tools does not start client when already started.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) + + provider = MCPToolProvider(client=mock_mcp_client) + provider._started = True + + tools = await provider.load_tools() + + mock_mcp_client.start.assert_not_called() + assert len(tools) == 1 + + +@pytest.mark.asyncio +async def test_load_tools_raises_exception_on_client_start_failure(mock_mcp_client): + """Test that load_tools raises ToolProviderException when client start fails.""" + mock_mcp_client.start.side_effect = Exception("Client start failed") + + provider = MCPToolProvider(client=mock_mcp_client) + + with pytest.raises(ToolProviderException, match="Failed to start MCP client: Client start failed"): + await provider.load_tools() + + +@pytest.mark.asyncio +async def test_load_tools_caches_tools(mock_mcp_client, mock_agent_tool): + """Test that load_tools caches tools and doesn't reload them.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) + + provider = MCPToolProvider(client=mock_mcp_client) + + # First call + tools1 = await provider.load_tools() + # Second call + tools2 = await provider.load_tools() + + # Client should only be called once + mock_mcp_client.list_tools_sync.assert_called_once() + assert tools1 is tools2 + + +@pytest.mark.asyncio +async def test_load_tools_handles_pagination(mock_mcp_client, mock_agent_tool): + """Test that load_tools handles pagination correctly.""" + tool1 = MagicMock(spec=MCPAgentTool) + tool1.tool_name = "tool1" + tool2 = MagicMock(spec=MCPAgentTool) + tool2.tool_name = "tool2" + + # Mock pagination: first page returns tool1 with next token, second page returns tool2 with no token + mock_mcp_client.list_tools_sync.side_effect = [ + PaginatedList([tool1], token="page2"), + PaginatedList([tool2], token=None), + ] + + provider = MCPToolProvider(client=mock_mcp_client) + + tools = await provider.load_tools() + + # Should have called list_tools_sync twice + assert mock_mcp_client.list_tools_sync.call_count == 2 + # First call with no token, second call with "page2" token + mock_mcp_client.list_tools_sync.assert_any_call(None) + mock_mcp_client.list_tools_sync.assert_any_call("page2") + + assert len(tools) == 2 + assert tools[0] is tool1 + assert tools[1] is tool2 + + +@pytest.mark.asyncio +async def test_allowed_filter_string_match(mock_mcp_client): + """Test allowed filter with string matching.""" + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) + + filters: ToolFilters = {"allowed": ["allowed_tool"]} + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "allowed_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_regex_match(mock_mcp_client): + """Test allowed filter with regex matching.""" + tool1 = create_mock_tool("echo_tool") + tool2 = create_mock_tool("other_tool") + + mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) + + filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "echo_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_callable_match(mock_mcp_client): + """Test allowed filter with callable matching.""" + tool1 = create_mock_tool("short") + tool2 = create_mock_tool("very_long_tool_name") + + mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 10 + + filters: ToolFilters = {"allowed": [short_names_only]} + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "short" + + +@pytest.mark.asyncio +async def test_rejected_filter(mock_mcp_client): + """Test rejected filter functionality.""" + tool1 = create_mock_tool("good_tool") + tool2 = create_mock_tool("bad_tool") + + mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) + + filters: ToolFilters = {"rejected": ["bad_tool"]} + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "good_tool" + + +@pytest.mark.asyncio +async def test_max_tools_filter(mock_mcp_client): + """Test max_tools filter functionality.""" + tools_list = [create_mock_tool(f"tool_{i}") for i in range(5)] + + mock_mcp_client.list_tools_sync.return_value = PaginatedList(tools_list) + + filters: ToolFilters = {"max_tools": 3} + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) + + tools = await provider.load_tools() + + assert len(tools) == 3 + assert all(tool.tool_name.startswith("tool_") for tool in tools) + + +@pytest.mark.asyncio +async def test_disambiguator_renames_tools(mock_mcp_client): + """Test that disambiguator properly renames tools.""" + original_tool = MagicMock(spec=MCPAgentTool) + original_tool.tool_name = "original_name" + original_tool.mcp_tool = MagicMock() + original_tool.mcp_tool.name = "original_name" + original_tool.mcp_client = mock_mcp_client + + mock_mcp_client.list_tools_sync.return_value = PaginatedList([original_tool]) + + with patch("strands.experimental.tools.mcp.mcp_tool_provider.MCPAgentTool") as mock_agent_tool_class: + new_tool = MagicMock(spec=MCPAgentTool) + new_tool.tool_name = "prefix_original_name" + mock_agent_tool_class.return_value = new_tool + + provider = MCPToolProvider(client=mock_mcp_client, disambiguator="prefix") + + tools = await provider.load_tools() + + # Should create new MCPAgentTool with prefixed name + mock_agent_tool_class.assert_called_once_with( + original_tool.mcp_tool, original_tool.mcp_client, agent_facing_tool_name="prefix_original_name" + ) + + assert len(tools) == 1 + assert tools[0] is new_tool + + +@pytest.mark.asyncio +async def test_cleanup_stops_client_when_started(mock_mcp_client): + """Test that cleanup stops the client when started.""" + provider = MCPToolProvider(client=mock_mcp_client) + provider._started = True + provider._tools = [MagicMock()] + + await provider.cleanup() + + mock_mcp_client.stop.assert_called_once_with(None, None, None) + assert provider._started is False + assert provider._tools is None + + +@pytest.mark.asyncio +async def test_cleanup_does_nothing_when_not_started(mock_mcp_client): + """Test that cleanup does nothing when not started.""" + provider = MCPToolProvider(client=mock_mcp_client) + provider._started = False + + await provider.cleanup() + + mock_mcp_client.stop.assert_not_called() + assert provider._started is False + + +@pytest.mark.asyncio +async def test_cleanup_raises_exception_on_client_stop_failure(mock_mcp_client): + """Test that cleanup raises ToolProviderException when client stop fails.""" + mock_mcp_client.stop.side_effect = Exception("Client stop failed") + + provider = MCPToolProvider(client=mock_mcp_client) + provider._started = True + + with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Client stop failed"): + await provider.cleanup() + + # State is not reset when cleanup fails + assert provider._started is True + assert provider._tools is None + + +@pytest.mark.asyncio +async def test_cleanup_does_not_reset_state_on_exception(mock_mcp_client): + """Test that cleanup does not reset state when exception occurs.""" + mock_mcp_client.stop.side_effect = Exception("Client stop failed") + + provider = MCPToolProvider(client=mock_mcp_client) + provider._started = True + mock_tool = MagicMock() + provider._tools = [mock_tool] + + with pytest.raises(ToolProviderException): + await provider.cleanup() + + # State should not be reset when exception occurs + assert provider._started is True + assert provider._tools == [mock_tool] + + +@pytest.mark.asyncio +async def test_load_tools_with_empty_tool_list(mock_mcp_client): + """Test load_tools with empty tool list from server.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([]) + + provider = MCPToolProvider(client=mock_mcp_client) + + tools = await provider.load_tools() + + assert len(tools) == 0 + assert provider._started is True + + +@pytest.mark.asyncio +async def test_load_tools_with_no_filters(mock_mcp_client, mock_agent_tool): + """Test load_tools with no filters applied.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) + + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=None) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0] is mock_agent_tool + + +@pytest.mark.asyncio +async def test_load_tools_with_empty_filters(mock_mcp_client, mock_agent_tool): + """Test load_tools with empty filters dict.""" + mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) + + provider = MCPToolProvider(client=mock_mcp_client, tool_filters={}) + + tools = await provider.load_tools() + + assert len(tools) == 1 + assert tools[0] is mock_agent_tool diff --git a/tests/strands/test_async.py b/tests/strands/test_async.py new file mode 100644 index 000000000..2a98a953c --- /dev/null +++ b/tests/strands/test_async.py @@ -0,0 +1,25 @@ +"""Tests for _async module.""" + +import pytest + +from strands._async import run_async + + +def test_run_async_with_return_value(): + """Test run_async returns correct value.""" + + async def async_with_value(): + return 42 + + result = run_async(async_with_value) + assert result == 42 + + +def test_run_async_exception_propagation(): + """Test that exceptions are properly propagated.""" + + async def async_with_exception(): + raise ValueError("test exception") + + with pytest.raises(ValueError, match="test exception"): + run_async(async_with_exception) diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py new file mode 100644 index 000000000..f9f9c9ce0 --- /dev/null +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -0,0 +1,202 @@ +"""Unit tests for ToolRegistry ToolProvider functionality.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from strands.experimental.tools.tool_provider import ToolProvider +from strands.tools.registry import ToolRegistry +from strands.types.tools import AgentTool + + +class MockToolProvider(ToolProvider): + """Mock ToolProvider for testing.""" + + def __init__(self, tools=None, cleanup_error=None): + self._tools = tools or [] + self._cleanup_error = cleanup_error + self.cleanup_called = False + + async def load_tools(self): + return self._tools + + async def cleanup(self): + self.cleanup_called = True + if self._cleanup_error: + raise self._cleanup_error + + +class TestToolRegistryToolProvider: + """Test ToolRegistry integration with ToolProvider.""" + + def test_process_tools_with_tool_provider(self): + """Test that process_tools handles ToolProvider correctly.""" + # Create mock tools + mock_tool1 = MagicMock(spec=AgentTool) + mock_tool1.tool_name = "provider_tool_1" + mock_tool2 = MagicMock(spec=AgentTool) + mock_tool2.tool_name = "provider_tool_2" + + # Create mock provider + provider = MockToolProvider([mock_tool1, mock_tool2]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + # Mock run_async to return the tools directly + mock_run_async.return_value = [mock_tool1, mock_tool2] + + tool_names = registry.process_tools([provider]) + + # Verify run_async was called with the provider's load_tools method + mock_run_async.assert_called_once() + + # Verify tools were registered + assert "provider_tool_1" in tool_names + assert "provider_tool_2" in tool_names + assert len(tool_names) == 2 + + # Verify provider was tracked + assert provider in registry.tool_providers + + # Verify tools are in registry + assert registry.registry["provider_tool_1"] is mock_tool1 + assert registry.registry["provider_tool_2"] is mock_tool2 + + def test_process_tools_with_multiple_providers(self): + """Test that process_tools handles multiple ToolProviders.""" + # Create mock tools for first provider + mock_tool1 = MagicMock(spec=AgentTool) + mock_tool1.tool_name = "provider1_tool" + provider1 = MockToolProvider([mock_tool1]) + + # Create mock tools for second provider + mock_tool2 = MagicMock(spec=AgentTool) + mock_tool2.tool_name = "provider2_tool" + provider2 = MockToolProvider([mock_tool2]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + # Mock run_async to return appropriate tools for each call + mock_run_async.side_effect = [[mock_tool1], [mock_tool2]] + + tool_names = registry.process_tools([provider1, provider2]) + + # Verify run_async was called twice + assert mock_run_async.call_count == 2 + + # Verify all tools were registered + assert "provider1_tool" in tool_names + assert "provider2_tool" in tool_names + assert len(tool_names) == 2 + + # Verify both providers were tracked + assert provider1 in registry.tool_providers + assert provider2 in registry.tool_providers + assert len(registry.tool_providers) == 2 + + def test_process_tools_with_mixed_tools_and_providers(self): + """Test that process_tools handles mix of regular tools and providers.""" + # Create regular tool + regular_tool = MagicMock(spec=AgentTool) + regular_tool.tool_name = "regular_tool" + + # Create provider tool + provider_tool = MagicMock(spec=AgentTool) + provider_tool.tool_name = "provider_tool" + provider = MockToolProvider([provider_tool]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + mock_run_async.return_value = [provider_tool] + + tool_names = registry.process_tools([regular_tool, provider]) + + # Verify both tools were registered + assert "regular_tool" in tool_names + assert "provider_tool" in tool_names + assert len(tool_names) == 2 + + # Verify only provider was tracked + assert provider in registry.tool_providers + assert len(registry.tool_providers) == 1 + + def test_process_tools_with_empty_provider(self): + """Test that process_tools handles provider with no tools.""" + provider = MockToolProvider([]) # Empty tools list + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + mock_run_async.return_value = [] + + tool_names = registry.process_tools([provider]) + + # Verify no tools were registered + assert not tool_names + + # Verify provider was still tracked + assert provider in registry.tool_providers + + def test_tool_providers_public_access(self): + """Test that tool_providers can be accessed directly.""" + provider1 = MockToolProvider() + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry.tool_providers = [provider1, provider2] + + # Verify direct access works + assert len(registry.tool_providers) == 2 + assert provider1 in registry.tool_providers + assert provider2 in registry.tool_providers + + def test_tool_providers_empty_by_default(self): + """Test that tool_providers is empty by default.""" + registry = ToolRegistry() + + assert not registry.tool_providers + assert isinstance(registry.tool_providers, list) + + def test_process_tools_provider_load_exception(self): + """Test that process_tools handles exceptions from provider.load_tools().""" + provider = MockToolProvider() + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + # Make load_tools raise an exception + mock_run_async.side_effect = Exception("Load tools failed") + + # Should raise the exception from load_tools + with pytest.raises(Exception, match="Load tools failed"): + registry.process_tools([provider]) + + # Provider should still be tracked even if load_tools failed + assert provider in registry.tool_providers + + def test_tool_provider_tracking_persistence(self): + """Test that tool providers are tracked across multiple process_tools calls.""" + provider1 = MockToolProvider([MagicMock(spec=AgentTool, tool_name="tool1")]) + provider2 = MockToolProvider([MagicMock(spec=AgentTool, tool_name="tool2")]) + + registry = ToolRegistry() + + with patch("strands.tools.registry.run_async") as mock_run_async: + mock_run_async.side_effect = [ + [MagicMock(spec=AgentTool, tool_name="tool1")], + [MagicMock(spec=AgentTool, tool_name="tool2")], + ] + + # Process first provider + registry.process_tools([provider1]) + assert len(registry.tool_providers) == 1 + assert provider1 in registry.tool_providers + + # Process second provider + registry.process_tools([provider2]) + assert len(registry.tool_providers) == 2 + assert provider1 in registry.tool_providers + assert provider2 in registry.tool_providers diff --git a/tests_integ/test_mcp_tool_provider.py b/tests_integ/test_mcp_tool_provider.py new file mode 100644 index 000000000..4d6a39329 --- /dev/null +++ b/tests_integ/test_mcp_tool_provider.py @@ -0,0 +1,171 @@ +"""Integration tests for MCPToolProvider with real MCP server.""" + +import logging +import re + +import pytest +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.experimental.tools.mcp import MCPToolProvider, ToolFilters +from strands.tools.mcp import MCPClient +from strands.types.exceptions import ToolProviderException + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger(__name__) + + +def test_mcp_tool_provider_filters(): + """Test MCPToolProvider with various filter combinations.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + # Test string filter, regex filter, callable filter, max_tools, and disambiguator + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 20 # Allow most tools + + filters: ToolFilters = { + "allowed": ["echo", re.compile(r"echo_with_.*"), short_names_only], + "rejected": ["echo_with_delay"], + "max_tools": 2, + } + + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="test") + agent = Agent(tools=[provider]) + tool_names = agent.tool_names + + # Should have 2 tools max, with test_ prefix, no delay tool + assert len(tool_names) == 2 + assert "echo_with_delay" not in [name.replace("test_", "") for name in tool_names] + assert all(name.startswith("test_") for name in tool_names) + + agent.cleanup() + + +def test_mcp_tool_provider_execution(): + """Test that MCPToolProvider works with agent execution.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + filters: ToolFilters = {"allowed": ["echo"]} + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="filtered") + agent = Agent( + tools=[provider], + ) + + # Verify the filtered tool exists + assert "filtered_echo" in agent.tool_names + + # # Test direct tool call to verify it works (use correct parameter name from echo server) + tool_result = agent.tool.filtered_echo(to_echo="Hello World") + assert "Hello World" in str(tool_result) + + # # Test agent execution using the tool + result = agent("Use the filtered_echo tool to echo whats inside the tags <>Integration Test") + assert "Integration Test" in str(result) + + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].call_count == 1 + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].success_count == 1 + + agent.cleanup() + + +def test_mcp_tool_provider_reuse(): + """Test that a single MCPToolProvider can be used across multiple agents.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + filters: ToolFilters = {"allowed": ["echo"]} + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="shared") + + # Create first agent with the provider + agent1 = Agent(tools=[provider]) + assert "shared_echo" in agent1.tool_names + + # Test first agent (use correct parameter name from echo server) + result1 = agent1.tool.shared_echo(to_echo="Agent 1") + assert "Agent 1" in str(result1) + + # Create second agent with the same provider + agent2 = Agent(tools=[provider]) + assert "shared_echo" in agent2.tool_names + + # Test second agent (use correct parameter name from echo server) + result2 = agent2.tool.shared_echo(to_echo="Agent 2") + assert "Agent 2" in str(result2) + + # Both agents should have the same tool count + assert len(agent1.tool_names) == len(agent2.tool_names) + assert agent1.tool_names == agent2.tool_names + + agent1.cleanup() + agent2.cleanup() + + +def test_mcp_tool_provider_multiple_servers(): + """Test MCPToolProvider with multiple MCP servers simultaneously.""" + # Create two separate MCP clients + client1 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + client2 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + # Create providers with different disambiguators + provider1 = MCPToolProvider(client=client1, tool_filters={"allowed": ["echo"]}, disambiguator="server1") + # Use correct tool name from echo_server.py + provider2 = MCPToolProvider( + client=client2, tool_filters={"allowed": ["echo_with_structured_content"]}, disambiguator="server2" + ) + + # Create agent with both providers + agent = Agent(tools=[provider1, provider2]) + + # Should have tools from both servers with different prefixes + assert "server1_echo" in agent.tool_names + assert "server2_echo_with_structured_content" in agent.tool_names + assert len(agent.tool_names) == 2 + + # Test tools from both servers work + result1 = agent.tool.server1_echo(to_echo="From Server 1") + assert "From Server 1" in str(result1) + + result2 = agent.tool.server2_echo_with_structured_content(to_echo="From Server 2") + assert "From Server 2" in str(result2) + + agent.cleanup() + + +def test_mcp_tool_provider_server_startup_failure(): + """Test that MCPToolProvider handles server startup failure gracefully without hanging.""" + # Create client with invalid command that will fail to start + failing_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="nonexistent_command", args=["--invalid"])), + startup_timeout=2, # Short timeout to avoid hanging + ) + + provider = MCPToolProvider(client=failing_client) + + # Should raise ToolProviderException when trying to load tools + with pytest.raises(ToolProviderException, match="Failed to start MCP client"): + Agent(tools=[provider]) + + +def test_mcp_tool_provider_server_connection_timeout(): + """Test that MCPToolProvider times out gracefully when server hangs during startup.""" + # Create client that will hang during connection + hanging_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="sleep", args=["10"])), # Sleep for 10 seconds + startup_timeout=1, # 1 second timeout + ) + + provider = MCPToolProvider(client=hanging_client) + + # Should raise ToolProviderException due to timeout + with pytest.raises(ToolProviderException, match="Failed to start MCP client"): + Agent(tools=[provider]) From 9dec5dfa165c2448530232bd232163802d506344 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 26 Sep 2025 14:48:58 -0300 Subject: [PATCH 02/35] remove max_tools, add kwargs --- src/strands/agent/agent.py | 3 ++ .../tools/mcp/mcp_tool_provider.py | 49 ++++++++----------- .../experimental/tools/tool_provider.py | 12 +++-- tests/strands/agent/test_agent.py | 12 +++++ .../tools/mcp/test_mcp_tool_provider.py | 32 +++--------- tests_integ/test_mcp_tool_provider.py | 14 +++--- 6 files changed, 59 insertions(+), 63 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 1eab8b639..6600d311a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -541,6 +541,9 @@ def cleanup(self) -> None: Note: This method uses a "belt and braces" approach with automatic cleanup through __del__ as a fallback, but explicit cleanup is recommended. """ + if self._cleanup_called: + return + run_async(self.cleanup_async) async def cleanup_async(self) -> None: diff --git a/src/strands/experimental/tools/mcp/mcp_tool_provider.py b/src/strands/experimental/tools/mcp/mcp_tool_provider.py index 610d642bc..44e9fb61f 100644 --- a/src/strands/experimental/tools/mcp/mcp_tool_provider.py +++ b/src/strands/experimental/tools/mcp/mcp_tool_provider.py @@ -1,7 +1,7 @@ """MCP Tool Provider implementation.""" import logging -from typing import Callable, Optional, Pattern, Sequence, Union +from typing import Any, Callable, Optional, Pattern, Sequence, Union from typing_extensions import TypedDict @@ -23,37 +23,39 @@ class ToolFilters(TypedDict, total=False): Tools are filtered in this order: 1. If 'allowed' is specified, only tools matching these patterns are included 2. Tools matching 'rejected' patterns are then excluded - 3. If the result exceeds 'max_tools', it's truncated """ allowed: list[_ToolFilterPattern] rejected: list[_ToolFilterPattern] - max_tools: int class MCPToolProvider(ToolProvider): """Tool provider for MCP clients with managed lifecycle.""" def __init__( - self, *, client: MCPClient, tool_filters: Optional[ToolFilters] = None, disambiguator: Optional[str] = None + self, + *, + client: MCPClient, + tool_filters: Optional[ToolFilters] = None, + prefix: Optional[str] = None, + **kwargs: Any, ) -> None: """Initialize with an MCP client. Args: client: The MCP client to manage. tool_filters: Optional filters to apply to tools. - disambiguator: Optional prefix for tool names. + prefix: Optional prefix for tool names. + **kwargs: Additional arguments for future compatibility. """ - logger.debug( - "tool_filters=<%s>, disambiguator=<%s> | initializing MCPToolProvider", tool_filters, disambiguator - ) + logger.debug("tool_filters=<%s>, prefix=<%s> | initializing MCPToolProvider", tool_filters, prefix) self._client = client self._tool_filters = tool_filters - self._disambiguator = disambiguator + self._prefix = prefix self._tools: Optional[list[MCPAgentTool]] = None # None = not loaded yet, [] = loaded but empty self._started = False - async def load_tools(self) -> Sequence[AgentTool]: + async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: """Load and return tools from the MCP client. Returns: @@ -77,12 +79,6 @@ async def load_tools(self) -> Sequence[AgentTool]: pagination_token = None page_count = 0 - # Determine max_tools limit for early termination - max_tools_limit = None - if self._tool_filters and "max_tools" in self._tool_filters: - max_tools_limit = self._tool_filters["max_tools"] - logger.debug("max_tools_limit=<%d> | will stop when reached", max_tools_limit) - while True: logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) paginated_tools = self._client.list_tools_sync(pagination_token) @@ -91,15 +87,10 @@ async def load_tools(self) -> Sequence[AgentTool]: for tool in paginated_tools: # Apply filters if self._should_include_tool(tool): - # Apply disambiguation if needed - processed_tool = self._apply_disambiguation(tool) + # Apply prefix if needed + processed_tool = self._apply_prefix(tool) self._tools.append(processed_tool) - # Check if we've reached max_tools limit - if max_tools_limit is not None and len(self._tools) >= max_tools_limit: - logger.debug("max_tools_reached=<%d> | stopping pagination early", len(self._tools)) - return self._tools - logger.debug( "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", page_count, @@ -134,14 +125,14 @@ def _should_include_tool(self, tool: MCPAgentTool) -> bool: return True - def _apply_disambiguation(self, tool: MCPAgentTool) -> MCPAgentTool: - """Apply disambiguation to a single tool if needed.""" - if not self._disambiguator: + def _apply_prefix(self, tool: MCPAgentTool) -> MCPAgentTool: + """Apply prefix to a single tool if needed.""" + if not self._prefix: return tool - # Create new tool with disambiguated agent name but preserve original MCP name + # Create new tool with prefixed agent name but preserve original MCP name old_name = tool.tool_name - new_agent_name = f"{self._disambiguator}_{tool.mcp_tool.name}" + new_agent_name = f"{self._prefix}_{tool.mcp_tool.name}" new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name) logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) return new_tool @@ -160,7 +151,7 @@ def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPatter return True return False - async def cleanup(self) -> None: + async def cleanup(self, **kwargs: Any) -> None: """Clean up the MCP client connection.""" if not self._started: return diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 0e8d54dfc..5a2bc94c3 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -1,7 +1,7 @@ """Tool provider interface.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Any, Sequence if TYPE_CHECKING: from ...types.tools import AgentTool @@ -15,18 +15,24 @@ class ToolProvider(ABC): """ @abstractmethod - async def load_tools(self) -> Sequence["AgentTool"]: + async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: """Load and return the tools in this provider. + Args: + **kwargs: Additional arguments for future compatibility. + Returns: List of tools that are ready to use. """ pass @abstractmethod - async def cleanup(self) -> None: + async def cleanup(self, **kwargs: Any) -> None: """Clean up resources used by the tools in this provider. + Args: + **kwargs: Additional arguments for future compatibility. + Should be called when the tools are no longer needed. """ pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index c9d0cf221..4b6e19c5b 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1015,6 +1015,18 @@ def test_agent_cleanup_idempotent(agent): mock_provider.cleanup.assert_called_once() +def test_agent_cleanup_early_return_avoids_thread_spawn(agent): + """Test that cleanup returns early when already called, avoiding thread spawn cost.""" + # Mark cleanup as already called + agent._cleanup_called = True + + with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: + agent.cleanup() + + # Verify run_async was not called since cleanup already happened + mock_run_async.assert_not_called() + + def test_agent__del__emits_warning_for_automatic_cleanup(agent): """Test that __del__ emits warning when cleanup wasn't called manually.""" # Add a mock tool provider so cleanup will be called diff --git a/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py b/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py index ae28d51c8..3576fd0b3 100644 --- a/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py +++ b/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py @@ -55,21 +55,21 @@ def test_init_with_client_only(mock_mcp_client): assert provider._client is mock_mcp_client assert provider._tool_filters is None - assert provider._disambiguator is None + assert provider._prefix is None assert provider._tools is None assert provider._started is False def test_init_with_all_parameters(mock_mcp_client): """Test initialization with all parameters.""" - filters = {"allowed": ["tool1"], "max_tools": 5} - disambiguator = "test_prefix" + filters = {"allowed": ["tool1"]} + prefix = "test_prefix" - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters, disambiguator=disambiguator) + provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters, prefix=prefix) assert provider._client is mock_mcp_client assert provider._tool_filters == filters - assert provider._disambiguator == disambiguator + assert provider._prefix == prefix assert provider._tools is None assert provider._started is False @@ -232,24 +232,8 @@ async def test_rejected_filter(mock_mcp_client): @pytest.mark.asyncio -async def test_max_tools_filter(mock_mcp_client): - """Test max_tools filter functionality.""" - tools_list = [create_mock_tool(f"tool_{i}") for i in range(5)] - - mock_mcp_client.list_tools_sync.return_value = PaginatedList(tools_list) - - filters: ToolFilters = {"max_tools": 3} - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) - - tools = await provider.load_tools() - - assert len(tools) == 3 - assert all(tool.tool_name.startswith("tool_") for tool in tools) - - -@pytest.mark.asyncio -async def test_disambiguator_renames_tools(mock_mcp_client): - """Test that disambiguator properly renames tools.""" +async def test_prefix_renames_tools(mock_mcp_client): + """Test that prefix properly renames tools.""" original_tool = MagicMock(spec=MCPAgentTool) original_tool.tool_name = "original_name" original_tool.mcp_tool = MagicMock() @@ -263,7 +247,7 @@ async def test_disambiguator_renames_tools(mock_mcp_client): new_tool.tool_name = "prefix_original_name" mock_agent_tool_class.return_value = new_tool - provider = MCPToolProvider(client=mock_mcp_client, disambiguator="prefix") + provider = MCPToolProvider(client=mock_mcp_client, prefix="prefix") tools = await provider.load_tools() diff --git a/tests_integ/test_mcp_tool_provider.py b/tests_integ/test_mcp_tool_provider.py index 4d6a39329..5b7bb3ed1 100644 --- a/tests_integ/test_mcp_tool_provider.py +++ b/tests_integ/test_mcp_tool_provider.py @@ -22,7 +22,7 @@ def test_mcp_tool_provider_filters(): lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) ) - # Test string filter, regex filter, callable filter, max_tools, and disambiguator + # Test string filter, regex filter, callable filter, and prefix def short_names_only(tool) -> bool: return len(tool.tool_name) <= 20 # Allow most tools @@ -32,7 +32,7 @@ def short_names_only(tool) -> bool: "max_tools": 2, } - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="test") + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="test") agent = Agent(tools=[provider]) tool_names = agent.tool_names @@ -51,7 +51,7 @@ def test_mcp_tool_provider_execution(): ) filters: ToolFilters = {"allowed": ["echo"]} - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="filtered") + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="filtered") agent = Agent( tools=[provider], ) @@ -80,7 +80,7 @@ def test_mcp_tool_provider_reuse(): ) filters: ToolFilters = {"allowed": ["echo"]} - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, disambiguator="shared") + provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="shared") # Create first agent with the provider agent1 = Agent(tools=[provider]) @@ -116,11 +116,11 @@ def test_mcp_tool_provider_multiple_servers(): lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) ) - # Create providers with different disambiguators - provider1 = MCPToolProvider(client=client1, tool_filters={"allowed": ["echo"]}, disambiguator="server1") + # Create providers with different prefixes + provider1 = MCPToolProvider(client=client1, tool_filters={"allowed": ["echo"]}, prefix="server1") # Use correct tool name from echo_server.py provider2 = MCPToolProvider( - client=client2, tool_filters={"allowed": ["echo_with_structured_content"]}, disambiguator="server2" + client=client2, tool_filters={"allowed": ["echo_with_structured_content"]}, prefix="server2" ) # Create agent with both providers From 7833a49fad0dfdd7cf1e7b63ab91216c30545157 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 8 Oct 2025 16:52:42 -0400 Subject: [PATCH 03/35] mcp_client implements tool_provider --- src/strands/agent/agent.py | 14 +- .../experimental/tools/mcp/__init__.py | 5 - .../tools/mcp/mcp_tool_provider.py | 171 --------- .../experimental/tools/tool_provider.py | 17 +- src/strands/tools/mcp/__init__.py | 4 +- src/strands/tools/mcp/mcp_client.py | 174 ++++++++- src/strands/tools/registry.py | 26 +- tests/strands/agent/test_agent.py | 31 +- .../experimental/tools/mcp/__init__.py | 0 .../tools/mcp/test_mcp_tool_provider.py | 359 ------------------ .../mcp/test_mcp_client_tool_provider.py | 320 ++++++++++++++++ .../tools/test_registry_tool_provider.py | 94 +++++ tests_integ/mcp/test_mcp_tool_provider.py | 184 +++++++++ tests_integ/test_mcp_tool_provider.py | 171 --------- 14 files changed, 824 insertions(+), 746 deletions(-) delete mode 100644 src/strands/experimental/tools/mcp/__init__.py delete mode 100644 src/strands/experimental/tools/mcp/mcp_tool_provider.py delete mode 100644 tests/strands/experimental/tools/mcp/__init__.py delete mode 100644 tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py create mode 100644 tests/strands/tools/mcp/test_mcp_client_tool_provider.py create mode 100644 tests_integ/mcp/test_mcp_tool_provider.py delete mode 100644 tests_integ/test_mcp_tool_provider.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 6600d311a..ee73f2200 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -561,19 +561,7 @@ async def cleanup_async(self) -> None: logger.debug("agent_id=<%s> | cleaning up agent resources", self.agent_id) - for provider in self.tool_registry.tool_providers: - try: - await provider.cleanup() - logger.debug( - "agent_id=<%s>, provider=<%s> | cleaned up tool provider", self.agent_id, type(provider).__name__ - ) - except Exception as e: - logger.warning( - "agent_id=<%s>, provider=<%s>, error=<%s> | failed to cleanup tool provider", - self.agent_id, - type(provider).__name__, - e, - ) + await self.tool_registry.cleanup_async() self._cleanup_called = True logger.debug("agent_id=<%s> | agent cleanup complete", self.agent_id) diff --git a/src/strands/experimental/tools/mcp/__init__.py b/src/strands/experimental/tools/mcp/__init__.py deleted file mode 100644 index ee1ccc542..000000000 --- a/src/strands/experimental/tools/mcp/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Experimental MCP Tool Provider.""" - -from .mcp_tool_provider import MCPToolProvider, ToolFilters - -__all__ = ["MCPToolProvider", "ToolFilters"] diff --git a/src/strands/experimental/tools/mcp/mcp_tool_provider.py b/src/strands/experimental/tools/mcp/mcp_tool_provider.py deleted file mode 100644 index 44e9fb61f..000000000 --- a/src/strands/experimental/tools/mcp/mcp_tool_provider.py +++ /dev/null @@ -1,171 +0,0 @@ -"""MCP Tool Provider implementation.""" - -import logging -from typing import Any, Callable, Optional, Pattern, Sequence, Union - -from typing_extensions import TypedDict - -from ....tools.mcp.mcp_agent_tool import MCPAgentTool -from ....tools.mcp.mcp_client import MCPClient -from ....types.exceptions import ToolProviderException -from ....types.tools import AgentTool -from ..tool_provider import ToolProvider - -logger = logging.getLogger(__name__) - -_ToolFilterCallback = Callable[[AgentTool], bool] -_ToolFilterPattern = Union[str, Pattern[str], _ToolFilterCallback] - - -class ToolFilters(TypedDict, total=False): - """Filters for controlling which MCP tools are loaded and available. - - Tools are filtered in this order: - 1. If 'allowed' is specified, only tools matching these patterns are included - 2. Tools matching 'rejected' patterns are then excluded - """ - - allowed: list[_ToolFilterPattern] - rejected: list[_ToolFilterPattern] - - -class MCPToolProvider(ToolProvider): - """Tool provider for MCP clients with managed lifecycle.""" - - def __init__( - self, - *, - client: MCPClient, - tool_filters: Optional[ToolFilters] = None, - prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Initialize with an MCP client. - - Args: - client: The MCP client to manage. - tool_filters: Optional filters to apply to tools. - prefix: Optional prefix for tool names. - **kwargs: Additional arguments for future compatibility. - """ - logger.debug("tool_filters=<%s>, prefix=<%s> | initializing MCPToolProvider", tool_filters, prefix) - self._client = client - self._tool_filters = tool_filters - self._prefix = prefix - self._tools: Optional[list[MCPAgentTool]] = None # None = not loaded yet, [] = loaded but empty - self._started = False - - async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: - """Load and return tools from the MCP client. - - Returns: - List of tools from the MCP server. - """ - logger.debug("started=<%s>, cached_tools=<%s> | loading tools", self._started, self._tools is not None) - - if not self._started: - try: - logger.debug("starting MCP client") - self._client.start() - self._started = True - logger.debug("MCP client started successfully") - except Exception as e: - logger.error("error=<%s> | failed to start MCP client", e) - raise ToolProviderException(f"Failed to start MCP client: {e}") from e - - if self._tools is None: - logger.debug("loading tools from MCP server") - self._tools = [] - pagination_token = None - page_count = 0 - - while True: - logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) - paginated_tools = self._client.list_tools_sync(pagination_token) - - # Process each tool as we get it - for tool in paginated_tools: - # Apply filters - if self._should_include_tool(tool): - # Apply prefix if needed - processed_tool = self._apply_prefix(tool) - self._tools.append(processed_tool) - - logger.debug( - "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", - page_count, - len(paginated_tools), - len(self._tools), - ) - - pagination_token = paginated_tools.pagination_token - page_count += 1 - - if pagination_token is None: - break - - logger.debug("final_tools=<%d> | loading complete", len(self._tools)) - - return self._tools - - def _should_include_tool(self, tool: MCPAgentTool) -> bool: - """Check if a tool should be included based on allowed/rejected filters.""" - if not self._tool_filters: - return True - - # Apply allowed filter - if "allowed" in self._tool_filters: - if not self._matches_patterns(tool, self._tool_filters["allowed"]): - return False - - # Apply rejected filter - if "rejected" in self._tool_filters: - if self._matches_patterns(tool, self._tool_filters["rejected"]): - return False - - return True - - def _apply_prefix(self, tool: MCPAgentTool) -> MCPAgentTool: - """Apply prefix to a single tool if needed.""" - if not self._prefix: - return tool - - # Create new tool with prefixed agent name but preserve original MCP name - old_name = tool.tool_name - new_agent_name = f"{self._prefix}_{tool.mcp_tool.name}" - new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name) - logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) - return new_tool - - def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool: - """Check if tool matches any of the given patterns.""" - for pattern in patterns: - if callable(pattern): - if pattern(tool): - return True - elif hasattr(pattern, "match") and hasattr(pattern, "pattern"): - if pattern.match(tool.tool_name): - return True - elif isinstance(pattern, str): - if pattern == tool.tool_name: - return True - return False - - async def cleanup(self, **kwargs: Any) -> None: - """Clean up the MCP client connection.""" - if not self._started: - return - - logger.debug("cleaning up MCP client") - try: - logger.debug("stopping MCP client") - self._client.stop(None, None, None) - logger.debug("MCP client stopped successfully") - except Exception as e: - logger.error("error=<%s> | failed to cleanup MCP client", e) - raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e - - # Only reset state if cleanup succeeded - self._started = False - self._tools = None - logger.debug("MCP client cleanup complete") diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 5a2bc94c3..5023dd72c 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -27,12 +27,23 @@ async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: pass @abstractmethod - async def cleanup(self, **kwargs: Any) -> None: - """Clean up resources used by the tools in this provider. + async def add_provider_consumer(self, id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider. Args: + id: Unique identifier for the consumer. + **kwargs: Additional arguments for future compatibility. + """ + pass + + @abstractmethod + async def remove_provider_consumer(self, id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider. + + Args: + id: Unique identifier for the consumer. **kwargs: Additional arguments for future compatibility. - Should be called when the tools are no longer needed. + Provider may clean up resources when no consumers remain. """ pass diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index d95c54fed..cfa841c46 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -7,7 +7,7 @@ """ from .mcp_agent_tool import MCPAgentTool -from .mcp_client import MCPClient +from .mcp_client import MCPClient, ToolFilters from .mcp_types import MCPTransport -__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport"] +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 8148e149a..b2d5887c3 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -16,7 +16,7 @@ from concurrent import futures from datetime import timedelta from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast +from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast import anyio from mcp import ClientSession, ListToolsResult @@ -25,11 +25,13 @@ from mcp.types import EmbeddedResource as MCPEmbeddedResource from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent +from typing_extensions import TypedDict +from ...experimental.tools import ToolProvider from ...types import PaginatedList -from ...types.exceptions import MCPClientInitializationError +from ...types.exceptions import MCPClientInitializationError, ToolProviderException from ...types.media import ImageFormat -from ...types.tools import ToolResultContent, ToolResultStatus +from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool from .mcp_instrumentation import mcp_instrumentation from .mcp_types import MCPToolResult, MCPTransport @@ -38,6 +40,22 @@ T = TypeVar("T") +_ToolFilterCallback = Callable[[AgentTool], bool] +_ToolFilterPattern = Union[str, Pattern[str], _ToolFilterCallback] + + +class ToolFilters(TypedDict, total=False): + """Filters for controlling which MCP tools are loaded and available. + + Tools are filtered in this order: + 1. If 'allowed' is specified, only tools matching these patterns are included + 2. Tools matching 'rejected' patterns are then excluded + """ + + allowed: list[_ToolFilterPattern] + rejected: list[_ToolFilterPattern] + + MIME_TO_FORMAT: Dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", @@ -53,7 +71,7 @@ ) -class MCPClient: +class MCPClient(ToolProvider): """Represents a connection to a Model Context Protocol (MCP) server. This class implements a context manager pattern for efficient connection management, @@ -65,15 +83,26 @@ class MCPClient: from MCP tools, it will be returned as the last item in the content array of the ToolResult. """ - def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_timeout: int = 30): + def __init__( + self, + transport_callable: Callable[[], MCPTransport], + *, + startup_timeout: int = 30, + tool_filters: Optional[ToolFilters] = None, + prefix: Optional[str] = None, + ): """Initialize a new MCP Server connection. Args: transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple startup_timeout: Timeout after which MCP server initialization should be cancelled Defaults to 30. + tool_filters: Optional filters to apply to tools. + prefix: Optional prefix for tool names. """ self._startup_timeout = startup_timeout + self._tool_filters = tool_filters + self._prefix = prefix mcp_instrumentation() self._session_id = uuid.uuid4() @@ -87,6 +116,9 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti self._background_thread: threading.Thread | None = None self._background_thread_session: ClientSession | None = None self._background_thread_event_loop: AbstractEventLoop | None = None + self._loaded_tools: list[MCPAgentTool] | None = None + self._tool_provider_started = False + self._consumers: set[Any] = set() def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. @@ -137,6 +169,92 @@ def start(self) -> "MCPClient": raise MCPClientInitializationError("the client initialization failed") from e return self + # ToolProvider interface methods (experimental, as ToolProvider is experimental) + async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: + """Load and return tools from the MCP server. + + This method implements the ToolProvider interface by loading tools + from the MCP server and caching them for reuse. + + Args: + **kwargs: Additional arguments for future compatibility. + + Returns: + List of AgentTool instances from the MCP server. + """ + logger.debug( + "started=<%s>, cached_tools=<%s> | loading tools", + self._tool_provider_started, + self._loaded_tools is not None, + ) + + if not self._tool_provider_started: + try: + logger.debug("starting MCP client") + self.start() + self._tool_provider_started = True + logger.debug("MCP client started successfully") + except Exception as e: + logger.error("error=<%s> | failed to start MCP client", e) + raise ToolProviderException(f"Failed to start MCP client: {e}") from e + + if self._loaded_tools is None: + logger.debug("loading tools from MCP server") + self._loaded_tools = [] + pagination_token = None + page_count = 0 + + while True: + logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) + paginated_tools = self.list_tools_sync(pagination_token) + + # Process each tool as we get it + for tool in paginated_tools: + # Apply filters + if self._should_include_tool(tool): + # Apply prefix if needed + processed_tool = self._apply_prefix(tool) + self._loaded_tools.append(processed_tool) + + logger.debug( + "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", + page_count, + len(paginated_tools), + len(self._loaded_tools), + ) + + pagination_token = paginated_tools.pagination_token + page_count += 1 + + if pagination_token is None: + break + + logger.debug("final_tools=<%d> | loading complete", len(self._loaded_tools)) + + return self._loaded_tools + + async def add_provider_consumer(self, id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider.""" + self._consumers.add(id) + logger.debug("added provider consumer, count=%d", len(self._consumers)) + + async def remove_provider_consumer(self, id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider.""" + self._consumers.discard(id) + logger.debug("removed provider consumer, count=%d", len(self._consumers)) + + if not self._consumers and self._tool_provider_started: + logger.debug("no consumers remaining, cleaning up") + try: + self.stop(None, None, None) + self._tool_provider_started = False + self._loaded_tools = None + except Exception as e: + logger.error("error=<%s> | failed to cleanup MCP client", e) + raise ToolProviderException(f"Failed to cleanup MCP client: {e}") from e + + # MCP-specific methods + def stop( self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] ) -> None: @@ -187,6 +305,9 @@ async def _set_close_event() -> None: self._background_thread_session = None self._background_thread_event_loop = None self._session_id = uuid.uuid4() + self._loaded_tools = None + self._tool_provider_started = False + self._consumers = set() def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedList[MCPAgentTool]: """Synchronously retrieves the list of available tools from the MCP server. @@ -530,5 +651,48 @@ def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures. raise MCPClientInitializationError("the client session was not initialized") return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) + def _should_include_tool(self, tool: MCPAgentTool) -> bool: + """Check if a tool should be included based on allowed/rejected filters.""" + if not self._tool_filters: + return True + + # Apply allowed filter + if "allowed" in self._tool_filters: + if not self._matches_patterns(tool, self._tool_filters["allowed"]): + return False + + # Apply rejected filter + if "rejected" in self._tool_filters: + if self._matches_patterns(tool, self._tool_filters["rejected"]): + return False + + return True + + def _apply_prefix(self, tool: MCPAgentTool) -> MCPAgentTool: + """Apply prefix to a single tool if needed.""" + if not self._prefix: + return tool + + # Create new tool with prefixed agent name but preserve original MCP name + old_name = tool.tool_name + new_agent_name = f"{self._prefix}_{tool.mcp_tool.name}" + new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name) + logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) + return new_tool + + def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool: + """Check if tool matches any of the given patterns.""" + for pattern in patterns: + if callable(pattern): + if pattern(tool): + return True + elif hasattr(pattern, "match") and hasattr(pattern, "pattern"): + if pattern.match(tool.tool_name): + return True + elif isinstance(pattern, str): + if pattern == tool.tool_name: + return True + return False + def _is_session_active(self) -> bool: return self._background_thread is not None and self._background_thread.is_alive() diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index ea71b09a0..52028ee32 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -8,11 +8,12 @@ import logging import os import sys +import uuid import warnings from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence from typing_extensions import TypedDict, cast @@ -39,6 +40,7 @@ def __init__(self) -> None: self.dynamic_tools: Dict[str, AgentTool] = {} self.tool_config: Optional[Dict[str, Any]] = None self.tool_providers: List[ToolProvider] = [] + self._registry_id = str(uuid.uuid4()) def process_tools(self, tools: List[Any]) -> List[str]: """Process tools list. @@ -121,12 +123,17 @@ def add_tool(tool: Any) -> None: elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): for t in tool: add_tool(t) - + # Case 5: ToolProvider elif isinstance(tool, ToolProvider): self.tool_providers.append(tool) - provider_tools = run_async(tool.load_tools) + async def get_tools_and_register_consumer() -> Sequence[AgentTool]: + provider_tools = await tool.load_tools() + await tool.add_provider_consumer(self._registry_id) + return provider_tools + + provider_tools = run_async(get_tools_and_register_consumer) for provider_tool in provider_tools: self.register_tool(provider_tool) @@ -653,3 +660,16 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: logger.warning("tool_name=<%s> | failed to create function tool | %s", name, e) return tools + + async def cleanup_async(self, **kwargs: Any) -> None: + """Clean up all tool providers in this registry.""" + for provider in self.tool_providers: + try: + await provider.remove_provider_consumer(self._registry_id) + logger.debug("provider=<%s> | removed provider consumer", type(provider).__name__) + except Exception as e: + logger.warning( + "provider=<%s>, error=<%s> | failed to remove provider consumer", + type(provider).__name__, + e, + ) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 4b6e19c5b..b342af131 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -914,15 +914,15 @@ async def test_agent_cleanup_async(agent): """Test that agent cleanup_async method works correctly.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.cleanup = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] await agent.cleanup_async() - # Verify provider cleanup was called - mock_provider.cleanup.assert_called_once() + # Verify provider remove_provider_consumer was called + mock_provider.remove_provider_consumer.assert_called_once_with(agent.tool_registry._registry_id) # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -932,9 +932,9 @@ async def test_agent_cleanup_async_handles_exceptions(agent): """Test that agent cleanup_async handles exceptions gracefully.""" # Create mock tool providers, one that raises an exception mock_provider1 = unittest.mock.MagicMock() - mock_provider1.cleanup = unittest.mock.AsyncMock() + mock_provider1.remove_provider_consumer = unittest.mock.AsyncMock() mock_provider2 = unittest.mock.MagicMock() - mock_provider2.cleanup = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + mock_provider2.remove_provider_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) # Add providers to agent's tool registry agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] @@ -943,8 +943,8 @@ async def test_agent_cleanup_async_handles_exceptions(agent): await agent.cleanup_async() # Verify both providers were attempted - mock_provider1.cleanup.assert_called_once() - mock_provider2.cleanup.assert_called_once() + mock_provider1.remove_provider_consumer.assert_called_once() + mock_provider2.remove_provider_consumer.assert_called_once() # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -954,7 +954,7 @@ async def test_agent_cleanup_async_idempotent(agent): """Test that calling cleanup_async multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.cleanup = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -963,8 +963,8 @@ async def test_agent_cleanup_async_idempotent(agent): await agent.cleanup_async() await agent.cleanup_async() - # Verify provider cleanup was only called once due to idempotency - mock_provider.cleanup.assert_called_once() + # Verify provider remove_provider_consumer was only called once due to idempotency + mock_provider.remove_provider_consumer.assert_called_once() @pytest.mark.asyncio @@ -1002,7 +1002,7 @@ def test_agent_cleanup_idempotent(agent): """Test that calling cleanup multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.cleanup = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -1011,8 +1011,8 @@ def test_agent_cleanup_idempotent(agent): agent.cleanup() agent.cleanup() - # Verify provider cleanup was only called once due to idempotency - mock_provider.cleanup.assert_called_once() + # Verify provider remove_provider_consumer was only called once due to idempotency + mock_provider.remove_provider_consumer.assert_called_once() def test_agent_cleanup_early_return_avoids_thread_spawn(agent): @@ -1027,8 +1027,11 @@ def test_agent_cleanup_early_return_avoids_thread_spawn(agent): mock_run_async.assert_not_called() -def test_agent__del__emits_warning_for_automatic_cleanup(agent): +def test_agent__del__emits_warning_for_automatic_cleanup(): """Test that __del__ emits warning when cleanup wasn't called manually.""" + # Create a fresh agent for this test to avoid fixture lifecycle issues + agent = Agent() + # Add a mock tool provider so cleanup will be called mock_provider = unittest.mock.MagicMock() agent.tool_registry.tool_providers = [mock_provider] diff --git a/tests/strands/experimental/tools/mcp/__init__.py b/tests/strands/experimental/tools/mcp/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py b/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py deleted file mode 100644 index 3576fd0b3..000000000 --- a/tests/strands/experimental/tools/mcp/test_mcp_tool_provider.py +++ /dev/null @@ -1,359 +0,0 @@ -"""Unit tests for MCPToolProvider.""" - -import re -from unittest.mock import MagicMock, patch - -import pytest - -from strands.experimental.tools.mcp import MCPToolProvider, ToolFilters -from strands.tools.mcp import MCPClient -from strands.tools.mcp.mcp_agent_tool import MCPAgentTool -from strands.types import PaginatedList -from strands.types.exceptions import ToolProviderException - - -@pytest.fixture -def mock_mcp_client(): - """Create a mock MCP client.""" - client = MagicMock(spec=MCPClient) - client.start = MagicMock() - client.stop = MagicMock() - client.list_tools_sync = MagicMock() - return client - - -@pytest.fixture -def mock_mcp_tool(): - """Create a mock MCP tool.""" - tool = MagicMock() - tool.name = "test_tool" - return tool - - -@pytest.fixture -def mock_agent_tool(mock_mcp_tool, mock_mcp_client): - """Create a mock MCPAgentTool.""" - agent_tool = MagicMock(spec=MCPAgentTool) - agent_tool.tool_name = "test_tool" - agent_tool.mcp_tool = mock_mcp_tool - agent_tool.mcp_client = mock_mcp_client - return agent_tool - - -def create_mock_tool(name: str) -> MagicMock: - """Helper to create mock tools with specific names.""" - tool = MagicMock(spec=MCPAgentTool) - tool.tool_name = name - tool.mcp_tool = MagicMock() - tool.mcp_tool.name = name - return tool - - -def test_init_with_client_only(mock_mcp_client): - """Test initialization with only client.""" - provider = MCPToolProvider(client=mock_mcp_client) - - assert provider._client is mock_mcp_client - assert provider._tool_filters is None - assert provider._prefix is None - assert provider._tools is None - assert provider._started is False - - -def test_init_with_all_parameters(mock_mcp_client): - """Test initialization with all parameters.""" - filters = {"allowed": ["tool1"]} - prefix = "test_prefix" - - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters, prefix=prefix) - - assert provider._client is mock_mcp_client - assert provider._tool_filters == filters - assert provider._prefix == prefix - assert provider._tools is None - assert provider._started is False - - -@pytest.mark.asyncio -async def test_load_tools_starts_client_when_not_started(mock_mcp_client, mock_agent_tool): - """Test that load_tools starts the client when not already started.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) - - provider = MCPToolProvider(client=mock_mcp_client) - - tools = await provider.load_tools() - - mock_mcp_client.start.assert_called_once() - assert provider._started is True - assert len(tools) == 1 - assert tools[0] is mock_agent_tool - - -@pytest.mark.asyncio -async def test_load_tools_does_not_start_client_when_already_started(mock_mcp_client, mock_agent_tool): - """Test that load_tools does not start client when already started.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) - - provider = MCPToolProvider(client=mock_mcp_client) - provider._started = True - - tools = await provider.load_tools() - - mock_mcp_client.start.assert_not_called() - assert len(tools) == 1 - - -@pytest.mark.asyncio -async def test_load_tools_raises_exception_on_client_start_failure(mock_mcp_client): - """Test that load_tools raises ToolProviderException when client start fails.""" - mock_mcp_client.start.side_effect = Exception("Client start failed") - - provider = MCPToolProvider(client=mock_mcp_client) - - with pytest.raises(ToolProviderException, match="Failed to start MCP client: Client start failed"): - await provider.load_tools() - - -@pytest.mark.asyncio -async def test_load_tools_caches_tools(mock_mcp_client, mock_agent_tool): - """Test that load_tools caches tools and doesn't reload them.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) - - provider = MCPToolProvider(client=mock_mcp_client) - - # First call - tools1 = await provider.load_tools() - # Second call - tools2 = await provider.load_tools() - - # Client should only be called once - mock_mcp_client.list_tools_sync.assert_called_once() - assert tools1 is tools2 - - -@pytest.mark.asyncio -async def test_load_tools_handles_pagination(mock_mcp_client, mock_agent_tool): - """Test that load_tools handles pagination correctly.""" - tool1 = MagicMock(spec=MCPAgentTool) - tool1.tool_name = "tool1" - tool2 = MagicMock(spec=MCPAgentTool) - tool2.tool_name = "tool2" - - # Mock pagination: first page returns tool1 with next token, second page returns tool2 with no token - mock_mcp_client.list_tools_sync.side_effect = [ - PaginatedList([tool1], token="page2"), - PaginatedList([tool2], token=None), - ] - - provider = MCPToolProvider(client=mock_mcp_client) - - tools = await provider.load_tools() - - # Should have called list_tools_sync twice - assert mock_mcp_client.list_tools_sync.call_count == 2 - # First call with no token, second call with "page2" token - mock_mcp_client.list_tools_sync.assert_any_call(None) - mock_mcp_client.list_tools_sync.assert_any_call("page2") - - assert len(tools) == 2 - assert tools[0] is tool1 - assert tools[1] is tool2 - - -@pytest.mark.asyncio -async def test_allowed_filter_string_match(mock_mcp_client): - """Test allowed filter with string matching.""" - tool1 = create_mock_tool("allowed_tool") - tool2 = create_mock_tool("rejected_tool") - - mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) - - filters: ToolFilters = {"allowed": ["allowed_tool"]} - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0].tool_name == "allowed_tool" - - -@pytest.mark.asyncio -async def test_allowed_filter_regex_match(mock_mcp_client): - """Test allowed filter with regex matching.""" - tool1 = create_mock_tool("echo_tool") - tool2 = create_mock_tool("other_tool") - - mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) - - filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0].tool_name == "echo_tool" - - -@pytest.mark.asyncio -async def test_allowed_filter_callable_match(mock_mcp_client): - """Test allowed filter with callable matching.""" - tool1 = create_mock_tool("short") - tool2 = create_mock_tool("very_long_tool_name") - - mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) - - def short_names_only(tool) -> bool: - return len(tool.tool_name) <= 10 - - filters: ToolFilters = {"allowed": [short_names_only]} - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0].tool_name == "short" - - -@pytest.mark.asyncio -async def test_rejected_filter(mock_mcp_client): - """Test rejected filter functionality.""" - tool1 = create_mock_tool("good_tool") - tool2 = create_mock_tool("bad_tool") - - mock_mcp_client.list_tools_sync.return_value = PaginatedList([tool1, tool2]) - - filters: ToolFilters = {"rejected": ["bad_tool"]} - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=filters) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0].tool_name == "good_tool" - - -@pytest.mark.asyncio -async def test_prefix_renames_tools(mock_mcp_client): - """Test that prefix properly renames tools.""" - original_tool = MagicMock(spec=MCPAgentTool) - original_tool.tool_name = "original_name" - original_tool.mcp_tool = MagicMock() - original_tool.mcp_tool.name = "original_name" - original_tool.mcp_client = mock_mcp_client - - mock_mcp_client.list_tools_sync.return_value = PaginatedList([original_tool]) - - with patch("strands.experimental.tools.mcp.mcp_tool_provider.MCPAgentTool") as mock_agent_tool_class: - new_tool = MagicMock(spec=MCPAgentTool) - new_tool.tool_name = "prefix_original_name" - mock_agent_tool_class.return_value = new_tool - - provider = MCPToolProvider(client=mock_mcp_client, prefix="prefix") - - tools = await provider.load_tools() - - # Should create new MCPAgentTool with prefixed name - mock_agent_tool_class.assert_called_once_with( - original_tool.mcp_tool, original_tool.mcp_client, agent_facing_tool_name="prefix_original_name" - ) - - assert len(tools) == 1 - assert tools[0] is new_tool - - -@pytest.mark.asyncio -async def test_cleanup_stops_client_when_started(mock_mcp_client): - """Test that cleanup stops the client when started.""" - provider = MCPToolProvider(client=mock_mcp_client) - provider._started = True - provider._tools = [MagicMock()] - - await provider.cleanup() - - mock_mcp_client.stop.assert_called_once_with(None, None, None) - assert provider._started is False - assert provider._tools is None - - -@pytest.mark.asyncio -async def test_cleanup_does_nothing_when_not_started(mock_mcp_client): - """Test that cleanup does nothing when not started.""" - provider = MCPToolProvider(client=mock_mcp_client) - provider._started = False - - await provider.cleanup() - - mock_mcp_client.stop.assert_not_called() - assert provider._started is False - - -@pytest.mark.asyncio -async def test_cleanup_raises_exception_on_client_stop_failure(mock_mcp_client): - """Test that cleanup raises ToolProviderException when client stop fails.""" - mock_mcp_client.stop.side_effect = Exception("Client stop failed") - - provider = MCPToolProvider(client=mock_mcp_client) - provider._started = True - - with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Client stop failed"): - await provider.cleanup() - - # State is not reset when cleanup fails - assert provider._started is True - assert provider._tools is None - - -@pytest.mark.asyncio -async def test_cleanup_does_not_reset_state_on_exception(mock_mcp_client): - """Test that cleanup does not reset state when exception occurs.""" - mock_mcp_client.stop.side_effect = Exception("Client stop failed") - - provider = MCPToolProvider(client=mock_mcp_client) - provider._started = True - mock_tool = MagicMock() - provider._tools = [mock_tool] - - with pytest.raises(ToolProviderException): - await provider.cleanup() - - # State should not be reset when exception occurs - assert provider._started is True - assert provider._tools == [mock_tool] - - -@pytest.mark.asyncio -async def test_load_tools_with_empty_tool_list(mock_mcp_client): - """Test load_tools with empty tool list from server.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([]) - - provider = MCPToolProvider(client=mock_mcp_client) - - tools = await provider.load_tools() - - assert len(tools) == 0 - assert provider._started is True - - -@pytest.mark.asyncio -async def test_load_tools_with_no_filters(mock_mcp_client, mock_agent_tool): - """Test load_tools with no filters applied.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) - - provider = MCPToolProvider(client=mock_mcp_client, tool_filters=None) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0] is mock_agent_tool - - -@pytest.mark.asyncio -async def test_load_tools_with_empty_filters(mock_mcp_client, mock_agent_tool): - """Test load_tools with empty filters dict.""" - mock_mcp_client.list_tools_sync.return_value = PaginatedList([mock_agent_tool]) - - provider = MCPToolProvider(client=mock_mcp_client, tool_filters={}) - - tools = await provider.load_tools() - - assert len(tools) == 1 - assert tools[0] is mock_agent_tool diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py new file mode 100644 index 000000000..187cbeaa5 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -0,0 +1,320 @@ +"""Unit tests for MCPClient ToolProvider functionality.""" + +import re +from unittest.mock import MagicMock, patch + +import pytest + +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool +from strands.tools.mcp.mcp_client import ToolFilters +from strands.types import PaginatedList +from strands.types.exceptions import ToolProviderException + + +@pytest.fixture +def mock_transport(): + """Create a mock transport callable.""" + + def transport(): + read_stream = MagicMock() + write_stream = MagicMock() + return read_stream, write_stream + + return transport + + +@pytest.fixture +def mock_mcp_tool(): + """Create a mock MCP tool.""" + tool = MagicMock() + tool.name = "test_tool" + return tool + + +@pytest.fixture +def mock_agent_tool(mock_mcp_tool): + """Create a mock MCPAgentTool.""" + agent_tool = MagicMock(spec=MCPAgentTool) + agent_tool.tool_name = "test_tool" + agent_tool.mcp_tool = mock_mcp_tool + return agent_tool + + +def create_mock_tool(name: str) -> MagicMock: + """Helper to create mock tools with specific names.""" + tool = MagicMock(spec=MCPAgentTool) + tool.tool_name = name + tool.mcp_tool = MagicMock() + tool.mcp_tool.name = name + return tool + + +def test_init_with_tool_filters_and_prefix(mock_transport): + """Test initialization with tool filters and prefix.""" + filters = {"allowed": ["tool1"]} + prefix = "test_prefix" + + client = MCPClient(mock_transport, tool_filters=filters, prefix=prefix) + + assert client._tool_filters == filters + assert client._prefix == prefix + assert client._loaded_tools is None + assert client._tool_provider_started is False + + +@pytest.mark.asyncio +async def test_load_tools_starts_client_when_not_started(mock_transport, mock_agent_tool): + """Test that load_tools starts the client when not already started.""" + client = MCPClient(mock_transport) + + with patch.object(client, "start") as mock_start, patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + tools = await client.load_tools() + + mock_start.assert_called_once() + assert client._tool_provider_started is True + assert len(tools) == 1 + assert tools[0] is mock_agent_tool + + +@pytest.mark.asyncio +async def test_load_tools_does_not_start_client_when_already_started(mock_transport, mock_agent_tool): + """Test that load_tools does not start client when already started.""" + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "start") as mock_start, patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + tools = await client.load_tools() + + mock_start.assert_not_called() + assert len(tools) == 1 + + +@pytest.mark.asyncio +async def test_load_tools_raises_exception_on_client_start_failure(mock_transport): + """Test that load_tools raises ToolProviderException when client start fails.""" + client = MCPClient(mock_transport) + + with patch.object(client, "start") as mock_start: + mock_start.side_effect = Exception("Client start failed") + + with pytest.raises(ToolProviderException, match="Failed to start MCP client: Client start failed"): + await client.load_tools() + + +@pytest.mark.asyncio +async def test_load_tools_caches_tools(mock_transport, mock_agent_tool): + """Test that load_tools caches tools and doesn't reload them.""" + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([mock_agent_tool]) + + # First call + tools1 = await client.load_tools() + # Second call + tools2 = await client.load_tools() + + # Client should only be called once + mock_list_tools.assert_called_once() + assert tools1 is tools2 + + +@pytest.mark.asyncio +async def test_load_tools_handles_pagination(mock_transport): + """Test that load_tools handles pagination correctly.""" + tool1 = create_mock_tool("tool1") + tool2 = create_mock_tool("tool2") + + client = MCPClient(mock_transport) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + # Mock pagination: first page returns tool1 with next token, second page returns tool2 with no token + mock_list_tools.side_effect = [ + PaginatedList([tool1], token="page2"), + PaginatedList([tool2], token=None), + ] + + tools = await client.load_tools() + + # Should have called list_tools_sync twice + assert mock_list_tools.call_count == 2 + # First call with no token, second call with "page2" token + mock_list_tools.assert_any_call(None) + mock_list_tools.assert_any_call("page2") + + assert len(tools) == 2 + assert tools[0] is tool1 + assert tools[1] is tool2 + + +@pytest.mark.asyncio +async def test_allowed_filter_string_match(mock_transport): + """Test allowed filter with string matching.""" + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + filters: ToolFilters = {"allowed": ["allowed_tool"]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "allowed_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_regex_match(mock_transport): + """Test allowed filter with regex matching.""" + tool1 = create_mock_tool("echo_tool") + tool2 = create_mock_tool("other_tool") + + filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "echo_tool" + + +@pytest.mark.asyncio +async def test_allowed_filter_callable_match(mock_transport): + """Test allowed filter with callable matching.""" + tool1 = create_mock_tool("short") + tool2 = create_mock_tool("very_long_tool_name") + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 10 + + filters: ToolFilters = {"allowed": [short_names_only]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "short" + + +@pytest.mark.asyncio +async def test_rejected_filter(mock_transport): + """Test rejected filter functionality.""" + tool1 = create_mock_tool("good_tool") + tool2 = create_mock_tool("bad_tool") + + filters: ToolFilters = {"rejected": ["bad_tool"]} + client = MCPClient(mock_transport, tool_filters=filters) + client._tool_provider_started = True + + with patch.object(client, "list_tools_sync") as mock_list_tools: + mock_list_tools.return_value = PaginatedList([tool1, tool2]) + + tools = await client.load_tools() + + assert len(tools) == 1 + assert tools[0].tool_name == "good_tool" + + +@pytest.mark.asyncio +async def test_prefix_renames_tools(mock_transport): + """Test that prefix properly renames tools.""" + original_tool = create_mock_tool("original_name") + original_tool.mcp_client = MagicMock() + + client = MCPClient(mock_transport, prefix="prefix") + client._tool_provider_started = True + + with ( + patch.object(client, "list_tools_sync") as mock_list_tools, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_list_tools.return_value = PaginatedList([original_tool]) + + new_tool = MagicMock(spec=MCPAgentTool) + new_tool.tool_name = "prefix_original_name" + mock_agent_tool_class.return_value = new_tool + + tools = await client.load_tools() + + # Should create new MCPAgentTool with prefixed name + mock_agent_tool_class.assert_called_once_with( + original_tool.mcp_tool, original_tool.mcp_client, agent_facing_tool_name="prefix_original_name" + ) + + assert len(tools) == 1 + assert tools[0] is new_tool + + +@pytest.mark.asyncio +async def test_add_provider_consumer(mock_transport): + """Test adding a provider consumer.""" + client = MCPClient(mock_transport) + + await client.add_provider_consumer("consumer1") + + assert "consumer1" in client._consumers + assert len(client._consumers) == 1 + + +@pytest.mark.asyncio +async def test_remove_provider_consumer_without_cleanup(mock_transport): + """Test removing a provider consumer without triggering cleanup.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._consumers.add("consumer2") + client._tool_provider_started = True + + await client.remove_provider_consumer("consumer1") + + assert "consumer1" not in client._consumers + assert "consumer2" in client._consumers + assert client._tool_provider_started is True # Should not cleanup yet + + +@pytest.mark.asyncio +async def test_remove_provider_consumer_with_cleanup(mock_transport): + """Test removing the last provider consumer triggers cleanup.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._tool_provider_started = True + client._loaded_tools = [MagicMock()] + + with patch.object(client, "stop") as mock_stop: + await client.remove_provider_consumer("consumer1") + + assert len(client._consumers) == 0 + assert client._tool_provider_started is False + assert client._loaded_tools is None + mock_stop.assert_called_once_with(None, None, None) + + +@pytest.mark.asyncio +async def test_remove_provider_consumer_cleanup_failure(mock_transport): + """Test that remove_provider_consumer raises ToolProviderException when cleanup fails.""" + client = MCPClient(mock_transport) + client._consumers.add("consumer1") + client._tool_provider_started = True + + with patch.object(client, "stop") as mock_stop: + mock_stop.side_effect = Exception("Cleanup failed") + + with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Cleanup failed"): + await client.remove_provider_consumer("consumer1") diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index f9f9c9ce0..c9794326f 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -16,6 +16,10 @@ def __init__(self, tools=None, cleanup_error=None): self._tools = tools or [] self._cleanup_error = cleanup_error self.cleanup_called = False + self.remove_consumer_called = False + self.remove_consumer_id = None + self.add_consumer_called = False + self.add_consumer_id = None async def load_tools(self): return self._tools @@ -25,6 +29,14 @@ async def cleanup(self): if self._cleanup_error: raise self._cleanup_error + async def add_provider_consumer(self, consumer_id): + self.add_consumer_called = True + self.add_consumer_id = consumer_id + + async def remove_provider_consumer(self, consumer_id): + self.remove_consumer_called = True + self.remove_consumer_id = consumer_id + class TestToolRegistryToolProvider: """Test ToolRegistry integration with ToolProvider.""" @@ -200,3 +212,85 @@ def test_tool_provider_tracking_persistence(self): assert len(registry.tool_providers) == 2 assert provider1 in registry.tool_providers assert provider2 in registry.tool_providers + + def test_process_tools_provider_async_optimization(self): + """Test that load_tools and add_provider_consumer are called in same async context.""" + mock_tool = MagicMock(spec=AgentTool) + mock_tool.tool_name = "test_tool" + + class TestProvider(ToolProvider): + def __init__(self): + self.load_tools_called = False + self.add_consumer_called = False + self.add_consumer_id = None + + async def load_tools(self): + self.load_tools_called = True + return [mock_tool] + + async def add_provider_consumer(self, consumer_id): + self.add_consumer_called = True + self.add_consumer_id = consumer_id + + async def remove_provider_consumer(self, consumer_id): + pass + + provider = TestProvider() + registry = ToolRegistry() + + # Process the provider - this should call both methods in same async context + tool_names = registry.process_tools([provider]) + + # Verify both methods were called + assert provider.load_tools_called + assert provider.add_consumer_called + assert provider.add_consumer_id == registry._registry_id + + # Verify tool was registered + assert "test_tool" in tool_names + assert provider in registry.tool_providers + + @pytest.mark.asyncio + async def test_registry_cleanup(self): + """Test that registry cleanup calls remove_provider_consumer on all providers.""" + provider1 = MockToolProvider() + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry.tool_providers = [provider1, provider2] + + await registry.cleanup_async() + + # Verify both providers had remove_provider_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called + + @pytest.mark.asyncio + async def test_registry_cleanup_with_provider_consumer_removal(self): + """Test that cleanup removes provider consumers correctly.""" + + class TestProvider(ToolProvider): + def __init__(self): + self.remove_consumer_called = False + self.remove_consumer_id = None + + async def load_tools(self): + return [] + + async def add_provider_consumer(self, consumer_id): + pass + + async def remove_provider_consumer(self, consumer_id): + self.remove_consumer_called = True + self.remove_consumer_id = consumer_id + + provider = TestProvider() + registry = ToolRegistry() + registry.tool_providers = [provider] + + # Call cleanup + await registry.cleanup_async() + + # Verify remove_provider_consumer was called with correct ID + assert provider.remove_consumer_called + assert provider.remove_consumer_id == registry._registry_id diff --git a/tests_integ/mcp/test_mcp_tool_provider.py b/tests_integ/mcp/test_mcp_tool_provider.py new file mode 100644 index 000000000..b45b38b86 --- /dev/null +++ b/tests_integ/mcp/test_mcp_tool_provider.py @@ -0,0 +1,184 @@ +"""Integration tests for MCPClient ToolProvider functionality with real MCP server.""" + +import logging +import re + +import pytest +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_client import ToolFilters + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger(__name__) + + +def test_mcp_client_tool_provider_filters(): + """Test MCPClient with various filter combinations.""" + + def short_names_only(tool) -> bool: + return len(tool.tool_name) <= 20 + + filters: ToolFilters = { + "allowed": ["echo", re.compile(r"echo_with_.*"), short_names_only], + "rejected": ["echo_with_delay"], + } + + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="test", + ) + + agent = Agent(tools=[client]) + tool_names = agent.tool_names + + assert "echo_with_delay" not in [name.replace("test_", "") for name in tool_names] + assert all(name.startswith("test_") for name in tool_names) + + agent.cleanup() + + +def test_mcp_client_tool_provider_execution(): + """Test that MCPClient works with agent execution.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="filtered", + ) + + agent = Agent(tools=[client]) + + assert "filtered_echo" in agent.tool_names + + tool_result = agent.tool.filtered_echo(to_echo="Hello World") + assert "Hello World" in str(tool_result) + + result = agent("Use the filtered_echo tool to echo whats inside the tags <>Integration Test") + assert "Integration Test" in str(result) + + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].call_count == 1 + assert agent.event_loop_metrics.tool_metrics["filtered_echo"].success_count == 1 + + agent.cleanup() + + +def test_mcp_client_tool_provider_reuse(): + """Test that a single MCPClient can be used across multiple agents.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="shared", + ) + + agent1 = Agent(tools=[client]) + assert "shared_echo" in agent1.tool_names + + result1 = agent1.tool.shared_echo(to_echo="Agent 1") + assert "Agent 1" in str(result1) + + agent2 = Agent(tools=[client]) + assert "shared_echo" in agent2.tool_names + + result2 = agent2.tool.shared_echo(to_echo="Agent 2") + assert "Agent 2" in str(result2) + + assert len(agent1.tool_names) == len(agent2.tool_names) + assert agent1.tool_names == agent2.tool_names + + agent1.cleanup() + agent2.cleanup() + + +def test_mcp_client_reference_counting(): + """Test that MCPClient uses reference counting - cleanup only happens when last consumer is removed.""" + filters: ToolFilters = {"allowed": ["echo"]} + client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters=filters, + prefix="ref", + ) + + # Create two agents with the same client + agent1 = Agent(tools=[client]) + agent2 = Agent(tools=[client]) + + # Both should have the tool + assert "ref_echo" in agent1.tool_names + assert "ref_echo" in agent2.tool_names + + # Agent 1 uses the tool + result1 = agent1.tool.ref_echo(to_echo="Agent 1 Test") + assert "Agent 1 Test" in str(result1) + + # Agent 1 cleans up - client should still be active for agent 2 + agent1.cleanup() + + # Agent 2 should still be able to use the tool + result2 = agent2.tool.ref_echo(to_echo="Agent 2 Test") + assert "Agent 2 Test" in str(result2) + + # Agent 2 cleans up - now client should be fully cleaned up + agent2.cleanup() + + +def test_mcp_client_multiple_servers(): + """Test MCPClient with multiple MCP servers simultaneously.""" + client1 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters={"allowed": ["echo"]}, + prefix="server1", + ) + client2 = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), + tool_filters={"allowed": ["echo_with_structured_content"]}, + prefix="server2", + ) + + agent = Agent(tools=[client1, client2]) + + assert "server1_echo" in agent.tool_names + assert "server2_echo_with_structured_content" in agent.tool_names + assert len(agent.tool_names) == 2 + + result1 = agent.tool.server1_echo(to_echo="From Server 1") + assert "From Server 1" in str(result1) + + result2 = agent.tool.server2_echo_with_structured_content(to_echo="From Server 2") + assert "From Server 2" in str(result2) + + agent.cleanup() + + +def test_mcp_client_server_startup_failure(): + """Test that MCPClient handles server startup failure gracefully without hanging.""" + from strands.types.exceptions import ToolProviderException + + failing_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="nonexistent_command", args=["--invalid"])), + startup_timeout=2, + ) + + with pytest.raises(ValueError, match="Failed to load tool") as exc_info: + Agent(tools=[failing_client]) + + assert isinstance(exc_info.value.__cause__, ToolProviderException) + + +def test_mcp_client_server_connection_timeout(): + """Test that MCPClient times out gracefully when server hangs during startup.""" + from strands.types.exceptions import ToolProviderException + + hanging_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="sleep", args=["10"])), + startup_timeout=1, + ) + + with pytest.raises(ValueError, match="Failed to load tool") as exc_info: + Agent(tools=[hanging_client]) + + assert isinstance(exc_info.value.__cause__, ToolProviderException) diff --git a/tests_integ/test_mcp_tool_provider.py b/tests_integ/test_mcp_tool_provider.py deleted file mode 100644 index 5b7bb3ed1..000000000 --- a/tests_integ/test_mcp_tool_provider.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Integration tests for MCPToolProvider with real MCP server.""" - -import logging -import re - -import pytest -from mcp import StdioServerParameters, stdio_client - -from strands import Agent -from strands.experimental.tools.mcp import MCPToolProvider, ToolFilters -from strands.tools.mcp import MCPClient -from strands.types.exceptions import ToolProviderException - -logging.basicConfig(level=logging.DEBUG) - -logger = logging.getLogger(__name__) - - -def test_mcp_tool_provider_filters(): - """Test MCPToolProvider with various filter combinations.""" - stdio_mcp_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) - ) - - # Test string filter, regex filter, callable filter, and prefix - def short_names_only(tool) -> bool: - return len(tool.tool_name) <= 20 # Allow most tools - - filters: ToolFilters = { - "allowed": ["echo", re.compile(r"echo_with_.*"), short_names_only], - "rejected": ["echo_with_delay"], - "max_tools": 2, - } - - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="test") - agent = Agent(tools=[provider]) - tool_names = agent.tool_names - - # Should have 2 tools max, with test_ prefix, no delay tool - assert len(tool_names) == 2 - assert "echo_with_delay" not in [name.replace("test_", "") for name in tool_names] - assert all(name.startswith("test_") for name in tool_names) - - agent.cleanup() - - -def test_mcp_tool_provider_execution(): - """Test that MCPToolProvider works with agent execution.""" - stdio_mcp_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) - ) - - filters: ToolFilters = {"allowed": ["echo"]} - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="filtered") - agent = Agent( - tools=[provider], - ) - - # Verify the filtered tool exists - assert "filtered_echo" in agent.tool_names - - # # Test direct tool call to verify it works (use correct parameter name from echo server) - tool_result = agent.tool.filtered_echo(to_echo="Hello World") - assert "Hello World" in str(tool_result) - - # # Test agent execution using the tool - result = agent("Use the filtered_echo tool to echo whats inside the tags <>Integration Test") - assert "Integration Test" in str(result) - - assert agent.event_loop_metrics.tool_metrics["filtered_echo"].call_count == 1 - assert agent.event_loop_metrics.tool_metrics["filtered_echo"].success_count == 1 - - agent.cleanup() - - -def test_mcp_tool_provider_reuse(): - """Test that a single MCPToolProvider can be used across multiple agents.""" - stdio_mcp_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) - ) - - filters: ToolFilters = {"allowed": ["echo"]} - provider = MCPToolProvider(client=stdio_mcp_client, tool_filters=filters, prefix="shared") - - # Create first agent with the provider - agent1 = Agent(tools=[provider]) - assert "shared_echo" in agent1.tool_names - - # Test first agent (use correct parameter name from echo server) - result1 = agent1.tool.shared_echo(to_echo="Agent 1") - assert "Agent 1" in str(result1) - - # Create second agent with the same provider - agent2 = Agent(tools=[provider]) - assert "shared_echo" in agent2.tool_names - - # Test second agent (use correct parameter name from echo server) - result2 = agent2.tool.shared_echo(to_echo="Agent 2") - assert "Agent 2" in str(result2) - - # Both agents should have the same tool count - assert len(agent1.tool_names) == len(agent2.tool_names) - assert agent1.tool_names == agent2.tool_names - - agent1.cleanup() - agent2.cleanup() - - -def test_mcp_tool_provider_multiple_servers(): - """Test MCPToolProvider with multiple MCP servers simultaneously.""" - # Create two separate MCP clients - client1 = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) - ) - client2 = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) - ) - - # Create providers with different prefixes - provider1 = MCPToolProvider(client=client1, tool_filters={"allowed": ["echo"]}, prefix="server1") - # Use correct tool name from echo_server.py - provider2 = MCPToolProvider( - client=client2, tool_filters={"allowed": ["echo_with_structured_content"]}, prefix="server2" - ) - - # Create agent with both providers - agent = Agent(tools=[provider1, provider2]) - - # Should have tools from both servers with different prefixes - assert "server1_echo" in agent.tool_names - assert "server2_echo_with_structured_content" in agent.tool_names - assert len(agent.tool_names) == 2 - - # Test tools from both servers work - result1 = agent.tool.server1_echo(to_echo="From Server 1") - assert "From Server 1" in str(result1) - - result2 = agent.tool.server2_echo_with_structured_content(to_echo="From Server 2") - assert "From Server 2" in str(result2) - - agent.cleanup() - - -def test_mcp_tool_provider_server_startup_failure(): - """Test that MCPToolProvider handles server startup failure gracefully without hanging.""" - # Create client with invalid command that will fail to start - failing_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="nonexistent_command", args=["--invalid"])), - startup_timeout=2, # Short timeout to avoid hanging - ) - - provider = MCPToolProvider(client=failing_client) - - # Should raise ToolProviderException when trying to load tools - with pytest.raises(ToolProviderException, match="Failed to start MCP client"): - Agent(tools=[provider]) - - -def test_mcp_tool_provider_server_connection_timeout(): - """Test that MCPToolProvider times out gracefully when server hangs during startup.""" - # Create client that will hang during connection - hanging_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="sleep", args=["10"])), # Sleep for 10 seconds - startup_timeout=1, # 1 second timeout - ) - - provider = MCPToolProvider(client=hanging_client) - - # Should raise ToolProviderException due to timeout - with pytest.raises(ToolProviderException, match="Failed to start MCP client"): - Agent(tools=[provider]) From 6cba7d7b6b988c1a7f6f0e664c800f83a1c036ef Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 8 Oct 2025 17:54:41 -0400 Subject: [PATCH 04/35] fix code coverage skip --- .codecov.yml | 3 +++ src/strands/tools/mcp/mcp_agent_tool.py | 6 +++--- src/strands/tools/mcp/mcp_client.py | 14 +++++++++----- .../tools/mcp/test_mcp_client_tool_provider.py | 6 +++--- 4 files changed, 18 insertions(+), 11 deletions(-) create mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..36be0c484 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,3 @@ +coverage: + ignore: + - "src/strands/experimental/tools/mcp/mcp_tool_provider.py" diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index 91ec6216a..af0c069a1 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -28,20 +28,20 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", agent_facing_tool_name: str | None = None) -> None: + def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None: """Initialize a new MCPAgentTool instance. Args: mcp_tool: The MCP tool to adapt mcp_client: The MCP server connection to use for tool invocation - agent_facing_tool_name: Optional name to use for the agent tool (for disambiguation) + name_override: Optional name to use for the agent tool (for disambiguation) If None, uses the original MCP tool name """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client - self._agent_tool_name = agent_facing_tool_name or mcp_tool.name + self._agent_tool_name = name_override or mcp_tool.name @property def tool_name(self) -> str: diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index b2d5887c3..b2717058f 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -25,7 +25,7 @@ from mcp.types import EmbeddedResource as MCPEmbeddedResource from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent -from typing_extensions import TypedDict +from typing_extensions import Protocol, TypedDict from ...experimental.tools import ToolProvider from ...types import PaginatedList @@ -40,8 +40,12 @@ T = TypeVar("T") -_ToolFilterCallback = Callable[[AgentTool], bool] -_ToolFilterPattern = Union[str, Pattern[str], _ToolFilterCallback] + +class _ToolFilterCallback(Protocol): + def __call__(self, tool: AgentTool, **kwargs: Any) -> bool: ... + + +_ToolFilterPattern = str | Pattern[str] | _ToolFilterCallback class ToolFilters(TypedDict, total=False): @@ -676,7 +680,7 @@ def _apply_prefix(self, tool: MCPAgentTool) -> MCPAgentTool: # Create new tool with prefixed agent name but preserve original MCP name old_name = tool.tool_name new_agent_name = f"{self._prefix}_{tool.mcp_tool.name}" - new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, agent_facing_tool_name=new_agent_name) + new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, name_override=new_agent_name) logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) return new_tool @@ -686,7 +690,7 @@ def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPatter if callable(pattern): if pattern(tool): return True - elif hasattr(pattern, "match") and hasattr(pattern, "pattern"): + elif isinstance(pattern, Pattern): if pattern.match(tool.tool_name): return True elif isinstance(pattern, str): diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py index 187cbeaa5..59a9eb81a 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -215,8 +215,8 @@ def short_names_only(tool) -> bool: @pytest.mark.asyncio -async def test_rejected_filter(mock_transport): - """Test rejected filter functionality.""" +async def test_rejected_filter_string_match(mock_transport): + """Test rejected filter with string matching.""" tool1 = create_mock_tool("good_tool") tool2 = create_mock_tool("bad_tool") @@ -256,7 +256,7 @@ async def test_prefix_renames_tools(mock_transport): # Should create new MCPAgentTool with prefixed name mock_agent_tool_class.assert_called_once_with( - original_tool.mcp_tool, original_tool.mcp_client, agent_facing_tool_name="prefix_original_name" + original_tool.mcp_tool, original_tool.mcp_client, name_override="prefix_original_name" ) assert len(tools) == 1 From 8f39f8b87d1a6be96f335a4caf18f722febdf219 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 9 Oct 2025 13:58:30 -0400 Subject: [PATCH 05/35] comments --- .codecov.yml | 2 +- src/strands/agent/agent.py | 7 ++- .../experimental/tools/tool_provider.py | 4 +- src/strands/tools/mcp/mcp_client.py | 4 +- src/strands/tools/registry.py | 4 +- tests/strands/agent/test_agent.py | 55 ++++++++++--------- .../mcp/test_mcp_client_tool_provider.py | 18 +++--- .../tools/test_registry_tool_provider.py | 20 +++---- 8 files changed, 59 insertions(+), 55 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 36be0c484..866a0af3a 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,3 +1,3 @@ coverage: ignore: - - "src/strands/experimental/tools/mcp/mcp_tool_provider.py" + - "src/strands/experimental/tools/tool_provider.py" # This is an interface, cannot meaningfully cover diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ee73f2200..cf01a6645 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -575,10 +575,11 @@ def __del__(self) -> None: if self._cleanup_called or not self.tool_registry.tool_providers: return - logger.warning( - "agent_id=<%s> | Agent cleanup called via __del__. " + warnings.warn( + f"agent_id={self.agent_id} | Agent cleanup called via __del__. " "Consider calling agent.cleanup() explicitly for better resource management.", - self.agent_id, + ResourceWarning, + stacklevel=2, ) self.cleanup() except Exception as e: diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 5023dd72c..401555368 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -27,7 +27,7 @@ async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: pass @abstractmethod - async def add_provider_consumer(self, id: Any, **kwargs: Any) -> None: + async def add_consumer(self, id: Any, **kwargs: Any) -> None: """Add a consumer to this tool provider. Args: @@ -37,7 +37,7 @@ async def add_provider_consumer(self, id: Any, **kwargs: Any) -> None: pass @abstractmethod - async def remove_provider_consumer(self, id: Any, **kwargs: Any) -> None: + async def remove_consumer(self, id: Any, **kwargs: Any) -> None: """Remove a consumer from this tool provider. Args: diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index b2717058f..02df09190 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -237,12 +237,12 @@ async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: return self._loaded_tools - async def add_provider_consumer(self, id: Any, **kwargs: Any) -> None: + async def add_consumer(self, id: Any, **kwargs: Any) -> None: """Add a consumer to this tool provider.""" self._consumers.add(id) logger.debug("added provider consumer, count=%d", len(self._consumers)) - async def remove_provider_consumer(self, id: Any, **kwargs: Any) -> None: + async def remove_consumer(self, id: Any, **kwargs: Any) -> None: """Remove a consumer from this tool provider.""" self._consumers.discard(id) logger.debug("removed provider consumer, count=%d", len(self._consumers)) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 52028ee32..39fd508d0 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -130,7 +130,7 @@ def add_tool(tool: Any) -> None: async def get_tools_and_register_consumer() -> Sequence[AgentTool]: provider_tools = await tool.load_tools() - await tool.add_provider_consumer(self._registry_id) + await tool.add_consumer(self._registry_id) return provider_tools provider_tools = run_async(get_tools_and_register_consumer) @@ -665,7 +665,7 @@ async def cleanup_async(self, **kwargs: Any) -> None: """Clean up all tool providers in this registry.""" for provider in self.tool_providers: try: - await provider.remove_provider_consumer(self._registry_id) + await provider.remove_consumer(self._registry_id) logger.debug("provider=<%s> | removed provider consumer", type(provider).__name__) except Exception as e: logger.warning( diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index b342af131..0ff122a72 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -914,15 +914,15 @@ async def test_agent_cleanup_async(agent): """Test that agent cleanup_async method works correctly.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] await agent.cleanup_async() - # Verify provider remove_provider_consumer was called - mock_provider.remove_provider_consumer.assert_called_once_with(agent.tool_registry._registry_id) + # Verify provider remove_consumer was called + mock_provider.remove_consumer.assert_called_once_with(agent.tool_registry._registry_id) # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -932,9 +932,9 @@ async def test_agent_cleanup_async_handles_exceptions(agent): """Test that agent cleanup_async handles exceptions gracefully.""" # Create mock tool providers, one that raises an exception mock_provider1 = unittest.mock.MagicMock() - mock_provider1.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider1.remove_consumer = unittest.mock.AsyncMock() mock_provider2 = unittest.mock.MagicMock() - mock_provider2.remove_provider_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + mock_provider2.remove_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) # Add providers to agent's tool registry agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] @@ -943,8 +943,8 @@ async def test_agent_cleanup_async_handles_exceptions(agent): await agent.cleanup_async() # Verify both providers were attempted - mock_provider1.remove_provider_consumer.assert_called_once() - mock_provider2.remove_provider_consumer.assert_called_once() + mock_provider1.remove_consumer.assert_called_once() + mock_provider2.remove_consumer.assert_called_once() # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -954,7 +954,7 @@ async def test_agent_cleanup_async_idempotent(agent): """Test that calling cleanup_async multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -963,8 +963,8 @@ async def test_agent_cleanup_async_idempotent(agent): await agent.cleanup_async() await agent.cleanup_async() - # Verify provider remove_provider_consumer was only called once due to idempotency - mock_provider.remove_provider_consumer.assert_called_once() + # Verify provider remove_consumer was only called once due to idempotency + mock_provider.remove_consumer.assert_called_once() @pytest.mark.asyncio @@ -1002,7 +1002,7 @@ def test_agent_cleanup_idempotent(agent): """Test that calling cleanup multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -1011,8 +1011,8 @@ def test_agent_cleanup_idempotent(agent): agent.cleanup() agent.cleanup() - # Verify provider remove_provider_consumer was only called once due to idempotency - mock_provider.remove_provider_consumer.assert_called_once() + # Verify provider remove_consumer was only called once due to idempotency + mock_provider.remove_consumer.assert_called_once() def test_agent_cleanup_early_return_avoids_thread_spawn(agent): @@ -1036,14 +1036,15 @@ def test_agent__del__emits_warning_for_automatic_cleanup(): mock_provider = unittest.mock.MagicMock() agent.tool_registry.tool_providers = [mock_provider] - with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") agent.__del__() - # Verify warning was logged - mock_logger.warning.assert_called_once() - warning_call = mock_logger.warning.call_args[0] - assert "Agent cleanup called via __del__" in warning_call[0] + # Verify warning was emitted + assert len(w) == 1 + assert issubclass(w[0].category, ResourceWarning) + assert "Agent cleanup called via __del__" in str(w[0].message) # Verify cleanup was called mock_cleanup.assert_called_once() @@ -1059,11 +1060,12 @@ def test_agent__del__no_warning_after_manual_cleanup(): with unittest.mock.patch.object(agent, "cleanup_async"): agent.cleanup() - with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") agent.__del__() - # Verify no warning was logged - mock_logger.warning.assert_not_called() + # Verify no warning was emitted + assert len(w) == 0 def test_agent__del__no_warning_when_no_tool_providers(): @@ -1076,12 +1078,13 @@ def test_agent__del__no_warning_when_no_tool_providers(): # Ensure no tool providers agent.tool_registry.tool_providers = [] - with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") agent.__del__() - # Verify no warning was logged and cleanup wasn't called - mock_logger.warning.assert_not_called() + # Verify no warning was emitted and cleanup wasn't called + assert len(w) == 0 mock_cleanup.assert_not_called() diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py index 59a9eb81a..094cc05b1 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -264,25 +264,25 @@ async def test_prefix_renames_tools(mock_transport): @pytest.mark.asyncio -async def test_add_provider_consumer(mock_transport): +async def test_add_consumer(mock_transport): """Test adding a provider consumer.""" client = MCPClient(mock_transport) - await client.add_provider_consumer("consumer1") + await client.add_consumer("consumer1") assert "consumer1" in client._consumers assert len(client._consumers) == 1 @pytest.mark.asyncio -async def test_remove_provider_consumer_without_cleanup(mock_transport): +async def test_remove_consumer_without_cleanup(mock_transport): """Test removing a provider consumer without triggering cleanup.""" client = MCPClient(mock_transport) client._consumers.add("consumer1") client._consumers.add("consumer2") client._tool_provider_started = True - await client.remove_provider_consumer("consumer1") + await client.remove_consumer("consumer1") assert "consumer1" not in client._consumers assert "consumer2" in client._consumers @@ -290,7 +290,7 @@ async def test_remove_provider_consumer_without_cleanup(mock_transport): @pytest.mark.asyncio -async def test_remove_provider_consumer_with_cleanup(mock_transport): +async def test_remove_consumer_with_cleanup(mock_transport): """Test removing the last provider consumer triggers cleanup.""" client = MCPClient(mock_transport) client._consumers.add("consumer1") @@ -298,7 +298,7 @@ async def test_remove_provider_consumer_with_cleanup(mock_transport): client._loaded_tools = [MagicMock()] with patch.object(client, "stop") as mock_stop: - await client.remove_provider_consumer("consumer1") + await client.remove_consumer("consumer1") assert len(client._consumers) == 0 assert client._tool_provider_started is False @@ -307,8 +307,8 @@ async def test_remove_provider_consumer_with_cleanup(mock_transport): @pytest.mark.asyncio -async def test_remove_provider_consumer_cleanup_failure(mock_transport): - """Test that remove_provider_consumer raises ToolProviderException when cleanup fails.""" +async def test_remove_consumer_cleanup_failure(mock_transport): + """Test that remove_consumer raises ToolProviderException when cleanup fails.""" client = MCPClient(mock_transport) client._consumers.add("consumer1") client._tool_provider_started = True @@ -317,4 +317,4 @@ async def test_remove_provider_consumer_cleanup_failure(mock_transport): mock_stop.side_effect = Exception("Cleanup failed") with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Cleanup failed"): - await client.remove_provider_consumer("consumer1") + await client.remove_consumer("consumer1") diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index c9794326f..ca10862dc 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -29,11 +29,11 @@ async def cleanup(self): if self._cleanup_error: raise self._cleanup_error - async def add_provider_consumer(self, consumer_id): + async def add_consumer(self, consumer_id): self.add_consumer_called = True self.add_consumer_id = consumer_id - async def remove_provider_consumer(self, consumer_id): + async def remove_consumer(self, consumer_id): self.remove_consumer_called = True self.remove_consumer_id = consumer_id @@ -214,7 +214,7 @@ def test_tool_provider_tracking_persistence(self): assert provider2 in registry.tool_providers def test_process_tools_provider_async_optimization(self): - """Test that load_tools and add_provider_consumer are called in same async context.""" + """Test that load_tools and add_consumer are called in same async context.""" mock_tool = MagicMock(spec=AgentTool) mock_tool.tool_name = "test_tool" @@ -228,11 +228,11 @@ async def load_tools(self): self.load_tools_called = True return [mock_tool] - async def add_provider_consumer(self, consumer_id): + async def add_consumer(self, consumer_id): self.add_consumer_called = True self.add_consumer_id = consumer_id - async def remove_provider_consumer(self, consumer_id): + async def remove_consumer(self, consumer_id): pass provider = TestProvider() @@ -252,7 +252,7 @@ async def remove_provider_consumer(self, consumer_id): @pytest.mark.asyncio async def test_registry_cleanup(self): - """Test that registry cleanup calls remove_provider_consumer on all providers.""" + """Test that registry cleanup calls remove_consumer on all providers.""" provider1 = MockToolProvider() provider2 = MockToolProvider() @@ -261,7 +261,7 @@ async def test_registry_cleanup(self): await registry.cleanup_async() - # Verify both providers had remove_provider_consumer called + # Verify both providers had remove_consumer called assert provider1.remove_consumer_called assert provider2.remove_consumer_called @@ -277,10 +277,10 @@ def __init__(self): async def load_tools(self): return [] - async def add_provider_consumer(self, consumer_id): + async def add_consumer(self, consumer_id): pass - async def remove_provider_consumer(self, consumer_id): + async def remove_consumer(self, consumer_id): self.remove_consumer_called = True self.remove_consumer_id = consumer_id @@ -291,6 +291,6 @@ async def remove_provider_consumer(self, consumer_id): # Call cleanup await registry.cleanup_async() - # Verify remove_provider_consumer was called with correct ID + # Verify remove_consumer was called with correct ID assert provider.remove_consumer_called assert provider.remove_consumer_id == registry._registry_id From abe5575a49540e80a05d37c5762a601c7805efe5 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 9 Oct 2025 14:30:47 -0400 Subject: [PATCH 06/35] comments --- src/strands/agent/agent.py | 8 ++--- tests/strands/agent/test_agent.py | 56 ++++++++++++++----------------- 2 files changed, 29 insertions(+), 35 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index cf01a6645..46c5390b1 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -12,7 +12,6 @@ import json import logging import random -import warnings from typing import ( Any, AsyncGenerator, @@ -575,11 +574,10 @@ def __del__(self) -> None: if self._cleanup_called or not self.tool_registry.tool_providers: return - warnings.warn( - f"agent_id={self.agent_id} | Agent cleanup called via __del__. " + logger.warning( + "agent_id=<%s> | Agent cleanup called via __del__. " "Consider calling agent.cleanup() explicitly for better resource management.", - ResourceWarning, - stacklevel=2, + self.agent_id, ) self.cleanup() except Exception as e: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 0ff122a72..386c14635 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -4,7 +4,6 @@ import os import textwrap import unittest.mock -import warnings from uuid import uuid4 import pytest @@ -914,15 +913,15 @@ async def test_agent_cleanup_async(agent): """Test that agent cleanup_async method works correctly.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_consumer = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] await agent.cleanup_async() - # Verify provider remove_consumer was called - mock_provider.remove_consumer.assert_called_once_with(agent.tool_registry._registry_id) + # Verify provider remove_provider_consumer was called + mock_provider.remove_provider_consumer.assert_called_once_with(agent.tool_registry._registry_id) # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -932,9 +931,9 @@ async def test_agent_cleanup_async_handles_exceptions(agent): """Test that agent cleanup_async handles exceptions gracefully.""" # Create mock tool providers, one that raises an exception mock_provider1 = unittest.mock.MagicMock() - mock_provider1.remove_consumer = unittest.mock.AsyncMock() + mock_provider1.remove_provider_consumer = unittest.mock.AsyncMock() mock_provider2 = unittest.mock.MagicMock() - mock_provider2.remove_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + mock_provider2.remove_provider_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) # Add providers to agent's tool registry agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] @@ -943,8 +942,8 @@ async def test_agent_cleanup_async_handles_exceptions(agent): await agent.cleanup_async() # Verify both providers were attempted - mock_provider1.remove_consumer.assert_called_once() - mock_provider2.remove_consumer.assert_called_once() + mock_provider1.remove_provider_consumer.assert_called_once() + mock_provider2.remove_provider_consumer.assert_called_once() # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -954,7 +953,7 @@ async def test_agent_cleanup_async_idempotent(agent): """Test that calling cleanup_async multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_consumer = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -963,8 +962,8 @@ async def test_agent_cleanup_async_idempotent(agent): await agent.cleanup_async() await agent.cleanup_async() - # Verify provider remove_consumer was only called once due to idempotency - mock_provider.remove_consumer.assert_called_once() + # Verify provider remove_provider_consumer was only called once due to idempotency + mock_provider.remove_provider_consumer.assert_called_once() @pytest.mark.asyncio @@ -1002,7 +1001,7 @@ def test_agent_cleanup_idempotent(agent): """Test that calling cleanup multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_consumer = unittest.mock.AsyncMock() + mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -1011,8 +1010,8 @@ def test_agent_cleanup_idempotent(agent): agent.cleanup() agent.cleanup() - # Verify provider remove_consumer was only called once due to idempotency - mock_provider.remove_consumer.assert_called_once() + # Verify provider remove_provider_consumer was only called once due to idempotency + mock_provider.remove_provider_consumer.assert_called_once() def test_agent_cleanup_early_return_avoids_thread_spawn(agent): @@ -1036,15 +1035,14 @@ def test_agent__del__emits_warning_for_automatic_cleanup(): mock_provider = unittest.mock.MagicMock() agent.tool_registry.tool_providers = [mock_provider] - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: agent.__del__() - # Verify warning was emitted - assert len(w) == 1 - assert issubclass(w[0].category, ResourceWarning) - assert "Agent cleanup called via __del__" in str(w[0].message) + # Verify warning was logged + mock_logger.warning.assert_called_once() + warning_call = mock_logger.warning.call_args[0] + assert "Agent cleanup called via __del__" in warning_call[0] # Verify cleanup was called mock_cleanup.assert_called_once() @@ -1060,12 +1058,11 @@ def test_agent__del__no_warning_after_manual_cleanup(): with unittest.mock.patch.object(agent, "cleanup_async"): agent.cleanup() - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: agent.__del__() - # Verify no warning was emitted - assert len(w) == 0 + # Verify no warning was logged + mock_logger.warning.assert_not_called() def test_agent__del__no_warning_when_no_tool_providers(): @@ -1078,13 +1075,12 @@ def test_agent__del__no_warning_when_no_tool_providers(): # Ensure no tool providers agent.tool_registry.tool_providers = [] - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: agent.__del__() - # Verify no warning was emitted and cleanup wasn't called - assert len(w) == 0 + # Verify no warning was logged and cleanup wasn't called + mock_logger.warning.assert_not_called() mock_cleanup.assert_not_called() From c804208b87ae12de2bc3334c07216b036f79b488 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 9 Oct 2025 14:46:28 -0400 Subject: [PATCH 07/35] comments --- tests/strands/agent/test_agent.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 386c14635..18074ce28 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -913,15 +913,15 @@ async def test_agent_cleanup_async(agent): """Test that agent cleanup_async method works correctly.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] await agent.cleanup_async() - # Verify provider remove_provider_consumer was called - mock_provider.remove_provider_consumer.assert_called_once_with(agent.tool_registry._registry_id) + # Verify provider remove_consumer was called + mock_provider.remove_consumer.assert_called_once_with(agent.tool_registry._registry_id) # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -931,9 +931,9 @@ async def test_agent_cleanup_async_handles_exceptions(agent): """Test that agent cleanup_async handles exceptions gracefully.""" # Create mock tool providers, one that raises an exception mock_provider1 = unittest.mock.MagicMock() - mock_provider1.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider1.remove_consumer = unittest.mock.AsyncMock() mock_provider2 = unittest.mock.MagicMock() - mock_provider2.remove_provider_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) + mock_provider2.remove_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) # Add providers to agent's tool registry agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] @@ -942,8 +942,8 @@ async def test_agent_cleanup_async_handles_exceptions(agent): await agent.cleanup_async() # Verify both providers were attempted - mock_provider1.remove_provider_consumer.assert_called_once() - mock_provider2.remove_provider_consumer.assert_called_once() + mock_provider1.remove_consumer.assert_called_once() + mock_provider2.remove_consumer.assert_called_once() # Verify cleanup was marked as called assert agent._cleanup_called is True @@ -953,7 +953,7 @@ async def test_agent_cleanup_async_idempotent(agent): """Test that calling cleanup_async multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -962,8 +962,8 @@ async def test_agent_cleanup_async_idempotent(agent): await agent.cleanup_async() await agent.cleanup_async() - # Verify provider remove_provider_consumer was only called once due to idempotency - mock_provider.remove_provider_consumer.assert_called_once() + # Verify provider remove_consumer was only called once due to idempotency + mock_provider.remove_consumer.assert_called_once() @pytest.mark.asyncio @@ -1001,7 +1001,7 @@ def test_agent_cleanup_idempotent(agent): """Test that calling cleanup multiple times is safe.""" # Create mock tool provider mock_provider = unittest.mock.MagicMock() - mock_provider.remove_provider_consumer = unittest.mock.AsyncMock() + mock_provider.remove_consumer = unittest.mock.AsyncMock() # Add provider to agent's tool registry agent.tool_registry.tool_providers = [mock_provider] @@ -1010,8 +1010,8 @@ def test_agent_cleanup_idempotent(agent): agent.cleanup() agent.cleanup() - # Verify provider remove_provider_consumer was only called once due to idempotency - mock_provider.remove_provider_consumer.assert_called_once() + # Verify provider remove_consumer was only called once due to idempotency + mock_provider.remove_consumer.assert_called_once() def test_agent_cleanup_early_return_avoids_thread_spawn(agent): From 3546ad4145cc1bcc9e048a34c126d58fc72e967b Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 15 Oct 2025 10:59:23 -0400 Subject: [PATCH 08/35] comments --- src/strands/agent/agent.py | 8 - src/strands/tools/registry.py | 6 +- tests/fixtures/mock_agent_tool.py | 35 +++ tests/strands/agent/test_agent.py | 4 +- .../tools/test_registry_tool_provider.py | 199 +++++++++--------- 5 files changed, 138 insertions(+), 114 deletions(-) create mode 100644 tests/fixtures/mock_agent_tool.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 46c5390b1..f75bcd948 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -571,14 +571,6 @@ def __del__(self) -> None: This serves as a fallback cleanup mechanism, but explicit cleanup() is preferred. """ try: - if self._cleanup_called or not self.tool_registry.tool_providers: - return - - logger.warning( - "agent_id=<%s> | Agent cleanup called via __del__. " - "Consider calling agent.cleanup() explicitly for better resource management.", - self.agent_id, - ) self.cleanup() except Exception as e: # Log exceptions during garbage collection cleanup for debugging diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 39fd508d0..b9f861f60 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -39,7 +39,7 @@ def __init__(self) -> None: self.registry: Dict[str, AgentTool] = {} self.dynamic_tools: Dict[str, AgentTool] = {} self.tool_config: Optional[Dict[str, Any]] = None - self.tool_providers: List[ToolProvider] = [] + self._tool_providers: List[ToolProvider] = [] self._registry_id = str(uuid.uuid4()) def process_tools(self, tools: List[Any]) -> List[str]: @@ -126,7 +126,7 @@ def add_tool(tool: Any) -> None: # Case 5: ToolProvider elif isinstance(tool, ToolProvider): - self.tool_providers.append(tool) + self._tool_providers.append(tool) async def get_tools_and_register_consumer() -> Sequence[AgentTool]: provider_tools = await tool.load_tools() @@ -663,7 +663,7 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: async def cleanup_async(self, **kwargs: Any) -> None: """Clean up all tool providers in this registry.""" - for provider in self.tool_providers: + for provider in self._tool_providers: try: await provider.remove_consumer(self._registry_id) logger.debug("provider=<%s> | removed provider consumer", type(provider).__name__) diff --git a/tests/fixtures/mock_agent_tool.py b/tests/fixtures/mock_agent_tool.py new file mode 100644 index 000000000..5d11bbdb8 --- /dev/null +++ b/tests/fixtures/mock_agent_tool.py @@ -0,0 +1,35 @@ +from typing import Any + +import pytest + +from strands.types.tools import AgentTool, ToolSpec +from strands.types.content import ToolUse + + +class MockAgentTool(AgentTool): + """Mock AgentTool implementation for testing.""" + + def __init__(self, name: str): + super().__init__() + self._tool_name = name + + @property + def tool_name(self) -> str: + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + return ToolSpec(name=self._tool_name, description="Mock tool", input_schema={}) + + @property + def tool_type(self) -> str: + return "mock" + + def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any): + yield f"Mock result for {self._tool_name}" + + +@pytest.fixture +def mock_agent_tool(): + """Fixture factory for creating MockAgentTool instances.""" + return MockAgentTool \ No newline at end of file diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 18074ce28..9e150dbd3 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1033,7 +1033,7 @@ def test_agent__del__emits_warning_for_automatic_cleanup(): # Add a mock tool provider so cleanup will be called mock_provider = unittest.mock.MagicMock() - agent.tool_registry.tool_providers = [mock_provider] + agent.tool_registry._tool_providers = [mock_provider] with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: @@ -1073,7 +1073,7 @@ def test_agent__del__no_warning_when_no_tool_providers(): agent = Agent() # Ensure no tool providers - agent.tool_registry.tool_providers = [] + agent.tool_registry._tool_providers = [] with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index ca10862dc..4d36ddbc7 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -6,7 +6,7 @@ from strands.experimental.tools.tool_provider import ToolProvider from strands.tools.registry import ToolRegistry -from strands.types.tools import AgentTool +from tests.fixtures.mock_agent_tool import mock_agent_tool class MockToolProvider(ToolProvider): @@ -38,119 +38,119 @@ async def remove_consumer(self, consumer_id): self.remove_consumer_id = consumer_id +@pytest.fixture +def mock_run_async(): + """Fixture for mocking strands.tools.registry.run_async.""" + with patch("strands.tools.registry.run_async") as mock: + yield mock + + + + + class TestToolRegistryToolProvider: """Test ToolRegistry integration with ToolProvider.""" - def test_process_tools_with_tool_provider(self): + def test_process_tools_with_tool_provider(self, mock_run_async, mock_agent_tool): """Test that process_tools handles ToolProvider correctly.""" # Create mock tools - mock_tool1 = MagicMock(spec=AgentTool) - mock_tool1.tool_name = "provider_tool_1" - mock_tool2 = MagicMock(spec=AgentTool) - mock_tool2.tool_name = "provider_tool_2" + mock_tool1 = mock_agent_tool("provider_tool_1") + mock_tool2 = mock_agent_tool("provider_tool_2") # Create mock provider provider = MockToolProvider([mock_tool1, mock_tool2]) registry = ToolRegistry() - with patch("strands.tools.registry.run_async") as mock_run_async: - # Mock run_async to return the tools directly - mock_run_async.return_value = [mock_tool1, mock_tool2] + # Mock run_async to return the tools directly + mock_run_async.return_value = [mock_tool1, mock_tool2] - tool_names = registry.process_tools([provider]) + tool_names = registry.process_tools([provider]) - # Verify run_async was called with the provider's load_tools method - mock_run_async.assert_called_once() + # Verify run_async was called with the provider's load_tools method + mock_run_async.assert_called_once() - # Verify tools were registered - assert "provider_tool_1" in tool_names - assert "provider_tool_2" in tool_names - assert len(tool_names) == 2 + # Verify tools were registered + assert "provider_tool_1" in tool_names + assert "provider_tool_2" in tool_names + assert len(tool_names) == 2 - # Verify provider was tracked - assert provider in registry.tool_providers + # Verify provider was tracked + assert provider in registry._tool_providers - # Verify tools are in registry - assert registry.registry["provider_tool_1"] is mock_tool1 - assert registry.registry["provider_tool_2"] is mock_tool2 + # Verify tools are in registry + assert registry.registry["provider_tool_1"] is mock_tool1 + assert registry.registry["provider_tool_2"] is mock_tool2 - def test_process_tools_with_multiple_providers(self): + def test_process_tools_with_multiple_providers(self, mock_run_async, mock_agent_tool): """Test that process_tools handles multiple ToolProviders.""" # Create mock tools for first provider - mock_tool1 = MagicMock(spec=AgentTool) - mock_tool1.tool_name = "provider1_tool" + mock_tool1 = mock_agent_tool("provider1_tool") provider1 = MockToolProvider([mock_tool1]) # Create mock tools for second provider - mock_tool2 = MagicMock(spec=AgentTool) - mock_tool2.tool_name = "provider2_tool" + mock_tool2 = mock_agent_tool("provider2_tool") provider2 = MockToolProvider([mock_tool2]) registry = ToolRegistry() - with patch("strands.tools.registry.run_async") as mock_run_async: - # Mock run_async to return appropriate tools for each call - mock_run_async.side_effect = [[mock_tool1], [mock_tool2]] + # Mock run_async to return appropriate tools for each call + mock_run_async.side_effect = [[mock_tool1], [mock_tool2]] - tool_names = registry.process_tools([provider1, provider2]) + tool_names = registry.process_tools([provider1, provider2]) - # Verify run_async was called twice - assert mock_run_async.call_count == 2 + # Verify run_async was called twice + assert mock_run_async.call_count == 2 - # Verify all tools were registered - assert "provider1_tool" in tool_names - assert "provider2_tool" in tool_names - assert len(tool_names) == 2 + # Verify all tools were registered + assert "provider1_tool" in tool_names + assert "provider2_tool" in tool_names + assert len(tool_names) == 2 - # Verify both providers were tracked - assert provider1 in registry.tool_providers - assert provider2 in registry.tool_providers - assert len(registry.tool_providers) == 2 + # Verify both providers were tracked + assert provider1 in registry._tool_providers + assert provider2 in registry._tool_providers + assert len(registry._tool_providers) == 2 - def test_process_tools_with_mixed_tools_and_providers(self): + def test_process_tools_with_mixed_tools_and_providers(self, mock_run_async, mock_agent_tool): """Test that process_tools handles mix of regular tools and providers.""" # Create regular tool - regular_tool = MagicMock(spec=AgentTool) - regular_tool.tool_name = "regular_tool" + regular_tool = mock_agent_tool("regular_tool") # Create provider tool - provider_tool = MagicMock(spec=AgentTool) - provider_tool.tool_name = "provider_tool" + provider_tool = mock_agent_tool("provider_tool") provider = MockToolProvider([provider_tool]) registry = ToolRegistry() - with patch("strands.tools.registry.run_async") as mock_run_async: - mock_run_async.return_value = [provider_tool] + mock_run_async.return_value = [provider_tool] - tool_names = registry.process_tools([regular_tool, provider]) + tool_names = registry.process_tools([regular_tool, provider]) - # Verify both tools were registered - assert "regular_tool" in tool_names - assert "provider_tool" in tool_names - assert len(tool_names) == 2 + # Verify both tools were registered + assert "regular_tool" in tool_names + assert "provider_tool" in tool_names + assert len(tool_names) == 2 - # Verify only provider was tracked - assert provider in registry.tool_providers - assert len(registry.tool_providers) == 1 + # Verify only provider was tracked + assert provider in registry._tool_providers + assert len(registry._tool_providers) == 1 - def test_process_tools_with_empty_provider(self): + def test_process_tools_with_empty_provider(self, mock_run_async): """Test that process_tools handles provider with no tools.""" provider = MockToolProvider([]) # Empty tools list registry = ToolRegistry() - with patch("strands.tools.registry.run_async") as mock_run_async: - mock_run_async.return_value = [] + mock_run_async.return_value = [] - tool_names = registry.process_tools([provider]) + tool_names = registry.process_tools([provider]) - # Verify no tools were registered - assert not tool_names + # Verify no tools were registered + assert not tool_names - # Verify provider was still tracked - assert provider in registry.tool_providers + # Verify provider was still tracked + assert provider in registry._tool_providers def test_tool_providers_public_access(self): """Test that tool_providers can be accessed directly.""" @@ -158,65 +158,62 @@ def test_tool_providers_public_access(self): provider2 = MockToolProvider() registry = ToolRegistry() - registry.tool_providers = [provider1, provider2] + registry._tool_providers = [provider1, provider2] # Verify direct access works - assert len(registry.tool_providers) == 2 - assert provider1 in registry.tool_providers - assert provider2 in registry.tool_providers + assert len(registry._tool_providers) == 2 + assert provider1 in registry._tool_providers + assert provider2 in registry._tool_providers def test_tool_providers_empty_by_default(self): """Test that tool_providers is empty by default.""" registry = ToolRegistry() - assert not registry.tool_providers - assert isinstance(registry.tool_providers, list) + assert not registry._tool_providers + assert isinstance(registry._tool_providers, list) - def test_process_tools_provider_load_exception(self): + def test_process_tools_provider_load_exception(self, mock_run_async): """Test that process_tools handles exceptions from provider.load_tools().""" provider = MockToolProvider() registry = ToolRegistry() - with patch("strands.tools.registry.run_async") as mock_run_async: - # Make load_tools raise an exception - mock_run_async.side_effect = Exception("Load tools failed") + # Make load_tools raise an exception + mock_run_async.side_effect = Exception("Load tools failed") - # Should raise the exception from load_tools - with pytest.raises(Exception, match="Load tools failed"): - registry.process_tools([provider]) + # Should raise the exception from load_tools + with pytest.raises(Exception, match="Load tools failed"): + registry.process_tools([provider]) - # Provider should still be tracked even if load_tools failed - assert provider in registry.tool_providers + # Provider should still be tracked even if load_tools failed + assert provider in registry._tool_providers - def test_tool_provider_tracking_persistence(self): + def test_tool_provider_tracking_persistence(self, mock_run_async, mock_agent_tool): """Test that tool providers are tracked across multiple process_tools calls.""" - provider1 = MockToolProvider([MagicMock(spec=AgentTool, tool_name="tool1")]) - provider2 = MockToolProvider([MagicMock(spec=AgentTool, tool_name="tool2")]) + provider1 = MockToolProvider([mock_agent_tool("tool1")]) + provider2 = MockToolProvider([mock_agent_tool("tool2")]) registry = ToolRegistry() - with patch("strands.tools.registry.run_async") as mock_run_async: - mock_run_async.side_effect = [ - [MagicMock(spec=AgentTool, tool_name="tool1")], - [MagicMock(spec=AgentTool, tool_name="tool2")], - ] + mock_run_async.side_effect = [ + [mock_agent_tool("tool1")], + [mock_agent_tool("tool2")], + ] - # Process first provider - registry.process_tools([provider1]) - assert len(registry.tool_providers) == 1 - assert provider1 in registry.tool_providers + # Process first provider + registry.process_tools([provider1]) + assert len(registry._tool_providers) == 1 + assert provider1 in registry._tool_providers - # Process second provider - registry.process_tools([provider2]) - assert len(registry.tool_providers) == 2 - assert provider1 in registry.tool_providers - assert provider2 in registry.tool_providers + # Process second provider + registry.process_tools([provider2]) + assert len(registry._tool_providers) == 2 + assert provider1 in registry._tool_providers + assert provider2 in registry._tool_providers - def test_process_tools_provider_async_optimization(self): + def test_process_tools_provider_async_optimization(self, mock_agent_tool): """Test that load_tools and add_consumer are called in same async context.""" - mock_tool = MagicMock(spec=AgentTool) - mock_tool.tool_name = "test_tool" + mock_tool = mock_agent_tool("test_tool") class TestProvider(ToolProvider): def __init__(self): @@ -248,7 +245,7 @@ async def remove_consumer(self, consumer_id): # Verify tool was registered assert "test_tool" in tool_names - assert provider in registry.tool_providers + assert provider in registry._tool_providers @pytest.mark.asyncio async def test_registry_cleanup(self): @@ -257,7 +254,7 @@ async def test_registry_cleanup(self): provider2 = MockToolProvider() registry = ToolRegistry() - registry.tool_providers = [provider1, provider2] + registry._tool_providers = [provider1, provider2] await registry.cleanup_async() @@ -286,7 +283,7 @@ async def remove_consumer(self, consumer_id): provider = TestProvider() registry = ToolRegistry() - registry.tool_providers = [provider] + registry._tool_providers = [provider] # Call cleanup await registry.cleanup_async() From d03b92417cb600fdd4d44553fde456c8cb1747be Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 15 Oct 2025 11:06:07 -0400 Subject: [PATCH 09/35] linting --- tests/fixtures/mock_agent_tool.py | 6 ------ tests/strands/tools/test_registry_tool_provider.py | 6 +++++- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/fixtures/mock_agent_tool.py b/tests/fixtures/mock_agent_tool.py index 5d11bbdb8..d564bdde8 100644 --- a/tests/fixtures/mock_agent_tool.py +++ b/tests/fixtures/mock_agent_tool.py @@ -27,9 +27,3 @@ def tool_type(self) -> str: def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any): yield f"Mock result for {self._tool_name}" - - -@pytest.fixture -def mock_agent_tool(): - """Fixture factory for creating MockAgentTool instances.""" - return MockAgentTool \ No newline at end of file diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index 4d36ddbc7..4069b1beb 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -6,7 +6,7 @@ from strands.experimental.tools.tool_provider import ToolProvider from strands.tools.registry import ToolRegistry -from tests.fixtures.mock_agent_tool import mock_agent_tool +from tests.fixtures.mock_agent_tool import MockAgentTool class MockToolProvider(ToolProvider): @@ -46,6 +46,10 @@ def mock_run_async(): +@pytest.fixture +def mock_agent_tool(): + """Fixture factory for creating MockAgentTool instances.""" + return MockAgentTool class TestToolRegistryToolProvider: From 9829035d7aed4d8ee6b2f88068c30055650e0d66 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 15 Oct 2025 11:07:47 -0400 Subject: [PATCH 10/35] linting --- tests/fixtures/mock_agent_tool.py | 14 ++++++-------- tests/strands/tools/test_registry_tool_provider.py | 3 +-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/fixtures/mock_agent_tool.py b/tests/fixtures/mock_agent_tool.py index d564bdde8..eed33731f 100644 --- a/tests/fixtures/mock_agent_tool.py +++ b/tests/fixtures/mock_agent_tool.py @@ -1,29 +1,27 @@ from typing import Any -import pytest - -from strands.types.tools import AgentTool, ToolSpec from strands.types.content import ToolUse +from strands.types.tools import AgentTool, ToolSpec class MockAgentTool(AgentTool): """Mock AgentTool implementation for testing.""" - + def __init__(self, name: str): super().__init__() self._tool_name = name - + @property def tool_name(self) -> str: return self._tool_name - + @property def tool_spec(self) -> ToolSpec: return ToolSpec(name=self._tool_name, description="Mock tool", input_schema={}) - + @property def tool_type(self) -> str: return "mock" - + def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any): yield f"Mock result for {self._tool_name}" diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index 4069b1beb..7fdb4e07c 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -1,6 +1,6 @@ """Unit tests for ToolRegistry ToolProvider functionality.""" -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest @@ -45,7 +45,6 @@ def mock_run_async(): yield mock - @pytest.fixture def mock_agent_tool(): """Fixture factory for creating MockAgentTool instances.""" From bf8760a67da7d5aa44b72e349bd1988305e46ba4 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 15 Oct 2025 11:15:01 -0400 Subject: [PATCH 11/35] fix rebase tests --- tests/strands/agent/test_agent.py | 123 +++++++++--------------------- 1 file changed, 35 insertions(+), 88 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 9e150dbd3..a06aa04f9 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -891,13 +891,6 @@ def test_agent_tool_names(tools, agent): def test_agent_cleanup(agent): """Test that agent cleanup method works correctly.""" - # Create mock tool provider - mock_provider = unittest.mock.MagicMock() - mock_provider.cleanup = unittest.mock.AsyncMock() - - # Add provider to agent's tool registry - agent.tool_registry.tool_providers = [mock_provider] - with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: agent.cleanup() @@ -911,80 +904,53 @@ def test_agent_cleanup(agent): @pytest.mark.asyncio async def test_agent_cleanup_async(agent): """Test that agent cleanup_async method works correctly.""" - # Create mock tool provider - mock_provider = unittest.mock.MagicMock() - mock_provider.remove_consumer = unittest.mock.AsyncMock() - - # Add provider to agent's tool registry - agent.tool_registry.tool_providers = [mock_provider] + with unittest.mock.patch.object(agent.tool_registry, "cleanup_async") as mock_registry_cleanup: + await agent.cleanup_async() - await agent.cleanup_async() - - # Verify provider remove_consumer was called - mock_provider.remove_consumer.assert_called_once_with(agent.tool_registry._registry_id) - # Verify cleanup was marked as called - assert agent._cleanup_called is True + # Verify registry cleanup was called + mock_registry_cleanup.assert_called_once() + # Verify cleanup was marked as called + assert agent._cleanup_called is True @pytest.mark.asyncio async def test_agent_cleanup_async_handles_exceptions(agent): """Test that agent cleanup_async handles exceptions gracefully.""" - # Create mock tool providers, one that raises an exception - mock_provider1 = unittest.mock.MagicMock() - mock_provider1.remove_consumer = unittest.mock.AsyncMock() - mock_provider2 = unittest.mock.MagicMock() - mock_provider2.remove_consumer = unittest.mock.AsyncMock(side_effect=Exception("Cleanup failed")) - - # Add providers to agent's tool registry - agent.tool_registry.tool_providers = [mock_provider1, mock_provider2] - - # Should not raise exception despite provider2 failing - await agent.cleanup_async() + with unittest.mock.patch.object(agent.tool_registry, "cleanup_async", side_effect=Exception("Registry cleanup failed")): + # Should not raise exception despite registry cleanup failing + await agent.cleanup_async() - # Verify both providers were attempted - mock_provider1.remove_consumer.assert_called_once() - mock_provider2.remove_consumer.assert_called_once() - # Verify cleanup was marked as called - assert agent._cleanup_called is True + # Verify cleanup was marked as called even if registry cleanup failed + assert agent._cleanup_called is True @pytest.mark.asyncio async def test_agent_cleanup_async_idempotent(agent): """Test that calling cleanup_async multiple times is safe.""" - # Create mock tool provider - mock_provider = unittest.mock.MagicMock() - mock_provider.remove_consumer = unittest.mock.AsyncMock() + with unittest.mock.patch.object(agent.tool_registry, "cleanup_async") as mock_registry_cleanup: + # Call cleanup_async twice + await agent.cleanup_async() + await agent.cleanup_async() - # Add provider to agent's tool registry - agent.tool_registry.tool_providers = [mock_provider] - - # Call cleanup_async twice - await agent.cleanup_async() - await agent.cleanup_async() - - # Verify provider remove_consumer was only called once due to idempotency - mock_provider.remove_consumer.assert_called_once() + # Verify registry cleanup was only called once due to idempotency + mock_registry_cleanup.assert_called_once() @pytest.mark.asyncio async def test_agent_cleanup_async_with_no_providers(agent): """Test that agent cleanup_async works when there are no tool providers.""" - # Ensure no providers - agent.tool_registry.tool_providers = [] - - # Should not raise any exceptions - await agent.cleanup_async() + with unittest.mock.patch.object(agent.tool_registry, "cleanup_async") as mock_registry_cleanup: + # Should not raise any exceptions + await agent.cleanup_async() - # Verify cleanup was marked as called - assert agent._cleanup_called is True + # Verify registry cleanup was called + mock_registry_cleanup.assert_called_once() + # Verify cleanup was marked as called + assert agent._cleanup_called is True def test_agent__del__(agent): """Test that agent destructor calls cleanup.""" - # Add a mock tool provider so cleanup will be called - mock_provider = unittest.mock.MagicMock() - agent.tool_registry.tool_providers = [mock_provider] - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: agent.__del__() mock_cleanup.assert_called_once() @@ -999,19 +965,13 @@ def test_agent__del__handles_cleanup_exception(agent): def test_agent_cleanup_idempotent(agent): """Test that calling cleanup multiple times is safe.""" - # Create mock tool provider - mock_provider = unittest.mock.MagicMock() - mock_provider.remove_consumer = unittest.mock.AsyncMock() - - # Add provider to agent's tool registry - agent.tool_registry.tool_providers = [mock_provider] - - # Call cleanup twice - agent.cleanup() - agent.cleanup() + with unittest.mock.patch.object(agent.tool_registry, "cleanup_async") as mock_registry_cleanup: + # Call cleanup twice + agent.cleanup() + agent.cleanup() - # Verify provider remove_consumer was only called once due to idempotency - mock_provider.remove_consumer.assert_called_once() + # Verify registry cleanup was only called once due to idempotency + mock_registry_cleanup.assert_called_once() def test_agent_cleanup_early_return_avoids_thread_spawn(agent): @@ -1031,19 +991,11 @@ def test_agent__del__emits_warning_for_automatic_cleanup(): # Create a fresh agent for this test to avoid fixture lifecycle issues agent = Agent() - # Add a mock tool provider so cleanup will be called - mock_provider = unittest.mock.MagicMock() - agent.tool_registry._tool_providers = [mock_provider] - with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: agent.__del__() - # Verify warning was logged - mock_logger.warning.assert_called_once() - warning_call = mock_logger.warning.call_args[0] - assert "Agent cleanup called via __del__" in warning_call[0] - # Verify cleanup was called + # Verify cleanup was called (warning logic is in agent.__del__ implementation) mock_cleanup.assert_called_once() @@ -1072,16 +1024,11 @@ def test_agent__del__no_warning_when_no_tool_providers(): agent = Agent() - # Ensure no tool providers - agent.tool_registry._tool_providers = [] - - with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: - agent.__del__() + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() - # Verify no warning was logged and cleanup wasn't called - mock_logger.warning.assert_not_called() - mock_cleanup.assert_not_called() + # Cleanup is always called in __del__, regardless of providers + mock_cleanup.assert_called_once() def test_agent_init_with_no_model_or_model_id(): From c71764b5e73603a182585d402eb0c8083f0295d9 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 15 Oct 2025 11:16:39 -0400 Subject: [PATCH 12/35] formatting' --- tests/strands/agent/test_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index a06aa04f9..520f83c4b 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -916,7 +916,9 @@ async def test_agent_cleanup_async(agent): @pytest.mark.asyncio async def test_agent_cleanup_async_handles_exceptions(agent): """Test that agent cleanup_async handles exceptions gracefully.""" - with unittest.mock.patch.object(agent.tool_registry, "cleanup_async", side_effect=Exception("Registry cleanup failed")): + with unittest.mock.patch.object( + agent.tool_registry, "cleanup_async", side_effect=Exception("Registry cleanup failed") + ): # Should not raise exception despite registry cleanup failing await agent.cleanup_async() From f0da75b6c61878948cc20faf0a3b888274891564 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 15 Oct 2025 11:19:21 -0400 Subject: [PATCH 13/35] formatting --- tests/strands/agent/test_agent.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 520f83c4b..6d0833433 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -993,12 +993,11 @@ def test_agent__del__emits_warning_for_automatic_cleanup(): # Create a fresh agent for this test to avoid fixture lifecycle issues agent = Agent() - with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: - agent.__del__() + with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: + agent.__del__() - # Verify cleanup was called (warning logic is in agent.__del__ implementation) - mock_cleanup.assert_called_once() + # Verify cleanup was called (warning logic is in agent.__del__ implementation) + mock_cleanup.assert_called_once() def test_agent__del__no_warning_after_manual_cleanup(): From f40c8f75d655706331e83e593c78f943848c0bb2 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 15 Oct 2025 11:37:59 -0400 Subject: [PATCH 14/35] clean --- tests/strands/agent/test_agent.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 6d0833433..b20fae585 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -913,19 +913,6 @@ async def test_agent_cleanup_async(agent): assert agent._cleanup_called is True -@pytest.mark.asyncio -async def test_agent_cleanup_async_handles_exceptions(agent): - """Test that agent cleanup_async handles exceptions gracefully.""" - with unittest.mock.patch.object( - agent.tool_registry, "cleanup_async", side_effect=Exception("Registry cleanup failed") - ): - # Should not raise exception despite registry cleanup failing - await agent.cleanup_async() - - # Verify cleanup was marked as called even if registry cleanup failed - assert agent._cleanup_called is True - - @pytest.mark.asyncio async def test_agent_cleanup_async_idempotent(agent): """Test that calling cleanup_async multiple times is safe.""" @@ -938,19 +925,6 @@ async def test_agent_cleanup_async_idempotent(agent): mock_registry_cleanup.assert_called_once() -@pytest.mark.asyncio -async def test_agent_cleanup_async_with_no_providers(agent): - """Test that agent cleanup_async works when there are no tool providers.""" - with unittest.mock.patch.object(agent.tool_registry, "cleanup_async") as mock_registry_cleanup: - # Should not raise any exceptions - await agent.cleanup_async() - - # Verify registry cleanup was called - mock_registry_cleanup.assert_called_once() - # Verify cleanup was marked as called - assert agent._cleanup_called is True - - def test_agent__del__(agent): """Test that agent destructor calls cleanup.""" with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: From 88e4ce27777d900292b52f9b3627e79d1290d284 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 15 Oct 2025 18:05:51 -0400 Subject: [PATCH 15/35] make tests more readable --- tests/strands/agent/test_agent.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index b20fae585..dc1573e3d 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -889,8 +889,11 @@ def test_agent_tool_names(tools, agent): assert actual == expected -def test_agent_cleanup(agent): +def test_agent_cleanup(): """Test that agent cleanup method works correctly.""" + # Create a fresh agent to avoid fixture interference + agent = Agent() + with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: agent.cleanup() From 3384915c3b817f4e3846bfd18a3cd52f4f71b6b0 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 15 Oct 2025 18:25:51 -0400 Subject: [PATCH 16/35] remove comment --- src/strands/agent/agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f75bcd948..7f6c7b8ab 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -573,7 +573,6 @@ def __del__(self) -> None: try: self.cleanup() except Exception as e: - # Log exceptions during garbage collection cleanup for debugging logger.debug("agent_id=<%s>, error=<%s> | exception during __del__ cleanup", self.agent_id, e) async def stream_async( From 6596b077dcc362823458010871ec90a159a7659a Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 20 Oct 2025 21:03:03 -0400 Subject: [PATCH 17/35] comments --- .codecov.yml | 12 ++- pyproject.toml | 3 +- src/strands/agent/agent.py | 21 ++-- src/strands/tools/mcp/mcp_client.py | 43 ++++---- tests/strands/agent/test_agent.py | 2 +- .../mcp/test_mcp_client_tool_provider.py | 98 +++++++++++++++---- tests_integ/mcp/test_mcp_tool_provider.py | 26 +---- 7 files changed, 128 insertions(+), 77 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 866a0af3a..5de0b79c2 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,3 +1,11 @@ coverage: - ignore: - - "src/strands/experimental/tools/tool_provider.py" # This is an interface, cannot meaningfully cover + status: + project: + default: + target: 90% # overall coverage threshold + patch: + default: + target: 90% # patch coverage threshold + base: auto + # Only post patch coverage on decreases + only_pulls: true \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b542c7481..214e7c45f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,7 @@ dependencies = [ "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.3.0", "pytest-xdist>=3.0.0,<4.0.0", + "pytest-timeout>=2.0.0,<3.0.0", "moto>=5.1.0,<6.0.0", ] @@ -141,7 +142,7 @@ dependencies = [ python = ["3.13", "3.12", "3.11", "3.10"] [tool.hatch.envs.hatch-test.scripts] -run = "pytest{env:HATCH_TEST_ARGS:} {args}" # Run with: hatch test +run = "pytest{env:HATCH_TEST_ARGS:} --timeout=10 {args}" # Run with: hatch test run-cov = "pytest{env:HATCH_TEST_ARGS:} {args} --cov --cov-config=pyproject.toml --cov-report html --cov-report xml {args}" # Run with: hatch test -c cov-combine = [] cov-report = [] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 7f6c7b8ab..4fc741f2c 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -540,7 +540,7 @@ def cleanup(self) -> None: Note: This method uses a "belt and braces" approach with automatic cleanup through __del__ as a fallback, but explicit cleanup is recommended. """ - if self._cleanup_called: + if getattr(self, "_cleanup_called", False): return run_async(self.cleanup_async) @@ -552,18 +552,22 @@ async def cleanup_async(self) -> None: such as MCP clients. It should be called when the agent is no longer needed to ensure proper resource cleanup. - Note: This method uses a "belt and braces" approach with automatic cleanup - through __del__ as a fallback, but explicit cleanup is recommended. + This method is idempotent and safe to call multiple times. """ - if self._cleanup_called: + # Use getattr with False default: if _cleanup_called was deleted during garbage collection, + # we default to False (cleanup not called) to ensure cleanup still runs + if getattr(self, "_cleanup_called", False): return - logger.debug("agent_id=<%s> | cleaning up agent resources", self.agent_id) + agent_id = getattr(self, "agent_id", None) + logger.debug("agent_id=<%s> | cleaning up agent resources", agent_id) - await self.tool_registry.cleanup_async() + tool_registry = getattr(self, "tool_registry", None) + if tool_registry: + await tool_registry.cleanup_async() self._cleanup_called = True - logger.debug("agent_id=<%s> | agent cleanup complete", self.agent_id) + logger.debug("agent_id=<%s> | agent cleanup complete", agent_id) def __del__(self) -> None: """Automatic cleanup when agent is garbage collected. @@ -573,7 +577,8 @@ def __del__(self) -> None: try: self.cleanup() except Exception as e: - logger.debug("agent_id=<%s>, error=<%s> | exception during __del__ cleanup", self.agent_id, e) + agent_id = getattr(self, "agent_id", None) + logger.debug("agent_id=<%s>, error=<%s> | exception during __del__ cleanup", agent_id, e) async def stream_async( self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 02df09190..50ea43a3e 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -210,15 +210,13 @@ async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: while True: logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) - paginated_tools = self.list_tools_sync(pagination_token) + paginated_tools = self.list_tools_sync(pagination_token, prefix=self._prefix) # Process each tool as we get it for tool in paginated_tools: # Apply filters if self._should_include_tool(tool): - # Apply prefix if needed - processed_tool = self._apply_prefix(tool) - self._loaded_tools.append(processed_tool) + self._loaded_tools.append(tool) logger.debug( "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", @@ -313,12 +311,18 @@ async def _set_close_event() -> None: self._tool_provider_started = False self._consumers = set() - def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedList[MCPAgentTool]: + def list_tools_sync( + self, pagination_token: Optional[str] = None, prefix: Optional[str] = None + ) -> PaginatedList[MCPAgentTool]: """Synchronously retrieves the list of available tools from the MCP server. This method calls the asynchronous list_tools method on the MCP session and adapts the returned tools to the AgentTool interface. + Args: + pagination_token: Optional token for pagination + prefix: Optional prefix to apply to tool names + Returns: List[AgentTool]: A list of available tools adapted to the AgentTool interface """ @@ -332,7 +336,16 @@ async def _list_tools_async() -> ListToolsResult: list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools)) - mcp_tools = [MCPAgentTool(tool, self) for tool in list_tools_response.tools] + mcp_tools = [] + for tool in list_tools_response.tools: + if prefix: + prefixed_name = f"{prefix}_{tool.name}" + mcp_tool = MCPAgentTool(tool, self, name_override=prefixed_name) + logger.debug("tool_rename=<%s->%s> | renamed tool", tool.name, prefixed_name) + else: + mcp_tool = MCPAgentTool(tool, self) + mcp_tools.append(mcp_tool) + self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) @@ -670,20 +683,9 @@ def _should_include_tool(self, tool: MCPAgentTool) -> bool: if self._matches_patterns(tool, self._tool_filters["rejected"]): return False + print(f"Returning true for {tool.mcp_tool.name} {tool.tool_name}") return True - def _apply_prefix(self, tool: MCPAgentTool) -> MCPAgentTool: - """Apply prefix to a single tool if needed.""" - if not self._prefix: - return tool - - # Create new tool with prefixed agent name but preserve original MCP name - old_name = tool.tool_name - new_agent_name = f"{self._prefix}_{tool.mcp_tool.name}" - new_tool = MCPAgentTool(tool.mcp_tool, tool.mcp_client, name_override=new_agent_name) - logger.debug("tool_rename=<%s->%s> | renamed tool", old_name, new_agent_name) - return new_tool - def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool: """Check if tool matches any of the given patterns.""" for pattern in patterns: @@ -691,10 +693,11 @@ def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPatter if pattern(tool): return True elif isinstance(pattern, Pattern): - if pattern.match(tool.tool_name): + if pattern.match(tool.mcp_tool.name): return True elif isinstance(pattern, str): - if pattern == tool.tool_name: + print(f"checking {pattern} against {tool.mcp_tool.name}") + if pattern == tool.mcp_tool.name: return True return False diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index dc1573e3d..49965caf1 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -893,7 +893,7 @@ def test_agent_cleanup(): """Test that agent cleanup method works correctly.""" # Create a fresh agent to avoid fixture interference agent = Agent() - + with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: agent.cleanup() diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py index 094cc05b1..42c617eef 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest +from mcp.types import Tool as MCPTool from strands.tools.mcp import MCPClient from strands.tools.mcp.mcp_agent_tool import MCPAgentTool @@ -41,12 +42,18 @@ def mock_agent_tool(mock_mcp_tool): return agent_tool -def create_mock_tool(name: str) -> MagicMock: +def create_mock_tool(tool_name: str, mcp_tool_name: str | None = None) -> MagicMock: """Helper to create mock tools with specific names.""" tool = MagicMock(spec=MCPAgentTool) - tool.tool_name = name - tool.mcp_tool = MagicMock() - tool.mcp_tool.name = name + tool.tool_name = tool_name + tool.tool_spec = { + "name": tool_name, + "description": f"Description for {tool_name}", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + tool.mcp_tool = MagicMock(spec=MCPTool) + tool.mcp_tool.name = mcp_tool_name or tool_name + tool.mcp_tool.description = f"Description for {tool_name}" return tool @@ -146,8 +153,8 @@ async def test_load_tools_handles_pagination(mock_transport): # Should have called list_tools_sync twice assert mock_list_tools.call_count == 2 # First call with no token, second call with "page2" token - mock_list_tools.assert_any_call(None) - mock_list_tools.assert_any_call("page2") + mock_list_tools.assert_any_call(None, prefix=None) + mock_list_tools.assert_any_call("page2", prefix=None) assert len(tools) == 2 assert tools[0] is tool1 @@ -236,31 +243,44 @@ async def test_rejected_filter_string_match(mock_transport): @pytest.mark.asyncio async def test_prefix_renames_tools(mock_transport): """Test that prefix properly renames tools.""" - original_tool = create_mock_tool("original_name") - original_tool.mcp_client = MagicMock() + # Create a mock MCP tool (not MCPAgentTool) + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_name" client = MCPClient(mock_transport, prefix="prefix") client._tool_provider_started = True + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + with ( - patch.object(client, "list_tools_sync") as mock_list_tools, + patch.object(client, "_invoke_on_background_thread") as mock_invoke, patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, ): - mock_list_tools.return_value = PaginatedList([original_tool]) + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None - new_tool = MagicMock(spec=MCPAgentTool) - new_tool.tool_name = "prefix_original_name" - mock_agent_tool_class.return_value = new_tool + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future - tools = await client.load_tools() + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "prefix_original_name" + mock_agent_tool_class.return_value = mock_agent_tool - # Should create new MCPAgentTool with prefixed name - mock_agent_tool_class.assert_called_once_with( - original_tool.mcp_tool, original_tool.mcp_client, name_override="prefix_original_name" - ) + # Call list_tools_sync directly to test prefix functionality + result = client.list_tools_sync(prefix="prefix") - assert len(tools) == 1 - assert tools[0] is new_tool + # Should create MCPAgentTool with prefixed name + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client, name_override="prefix_original_name") + + assert len(result) == 1 + assert result[0] is mock_agent_tool @pytest.mark.asyncio @@ -318,3 +338,41 @@ async def test_remove_consumer_cleanup_failure(mock_transport): with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Cleanup failed"): await client.remove_consumer("consumer1") + + +def test_mcp_client_reuse_across_multiple_agents(mock_transport): + """Test that a single MCPClient can be used across multiple agents.""" + from strands import Agent + + tool1 = create_mock_tool(tool_name="shared_echo", mcp_tool_name="echo") + client = MCPClient(mock_transport, tool_filters={"allowed": ["echo"]}, prefix="shared") + + with ( + patch.object(client, "list_tools_sync") as mock_list_tools, + patch.object(client, "start") as mock_start, + patch.object(client, "stop") as mock_stop, + ): + mock_list_tools.return_value = PaginatedList([tool1]) + + # Create two agents with the same client + agent_1 = Agent(tools=[client]) + agent_2 = Agent(tools=[client]) + + # Both agents should have the same tool + assert "shared_echo" in agent_1.tool_names + assert "shared_echo" in agent_2.tool_names + assert agent_1.tool_names == agent_2.tool_names + + # Client should only be started once + mock_start.assert_called_once() + + # First agent cleanup - client should remain active + agent_1.cleanup() + mock_stop.assert_not_called() # Should not stop yet + + # Second agent should still work + assert "shared_echo" in agent_2.tool_names + + # Final cleanup when last agent is removed + agent_2.cleanup() + mock_stop.assert_called_once() # Now it should stop diff --git a/tests_integ/mcp/test_mcp_tool_provider.py b/tests_integ/mcp/test_mcp_tool_provider.py index b45b38b86..acd32fbfe 100644 --- a/tests_integ/mcp/test_mcp_tool_provider.py +++ b/tests_integ/mcp/test_mcp_tool_provider.py @@ -35,7 +35,7 @@ def short_names_only(tool) -> bool: agent = Agent(tools=[client]) tool_names = agent.tool_names - assert "echo_with_delay" not in [name.replace("test_", "") for name in tool_names] + assert "test_echo_with_delay" not in [name for name in tool_names] assert all(name.startswith("test_") for name in tool_names) agent.cleanup() @@ -91,29 +91,6 @@ def test_mcp_client_tool_provider_reuse(): assert agent1.tool_names == agent2.tool_names agent1.cleanup() - agent2.cleanup() - - -def test_mcp_client_reference_counting(): - """Test that MCPClient uses reference counting - cleanup only happens when last consumer is removed.""" - filters: ToolFilters = {"allowed": ["echo"]} - client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])), - tool_filters=filters, - prefix="ref", - ) - - # Create two agents with the same client - agent1 = Agent(tools=[client]) - agent2 = Agent(tools=[client]) - - # Both should have the tool - assert "ref_echo" in agent1.tool_names - assert "ref_echo" in agent2.tool_names - - # Agent 1 uses the tool - result1 = agent1.tool.ref_echo(to_echo="Agent 1 Test") - assert "Agent 1 Test" in str(result1) # Agent 1 cleans up - client should still be active for agent 2 agent1.cleanup() @@ -122,7 +99,6 @@ def test_mcp_client_reference_counting(): result2 = agent2.tool.ref_echo(to_echo="Agent 2 Test") assert "Agent 2 Test" in str(result2) - # Agent 2 cleans up - now client should be fully cleaned up agent2.cleanup() From 5419a3f8da1a2dc7388c4abf1e14e5a8ee176d49 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 21 Oct 2025 13:40:38 -0400 Subject: [PATCH 18/35] fix: weakref instead of __del__ --- src/strands/agent/agent.py | 218 +++----- .../experimental/tools/tool_provider.py | 16 +- src/strands/tools/mcp/mcp_client.py | 72 ++- src/strands/tools/registry.py | 24 +- tests/strands/agent/test_agent.py | 120 ----- .../mcp/test_mcp_client_tool_provider.py | 492 +++++++++++++++++- .../tools/test_registry_tool_provider.py | 26 +- 7 files changed, 635 insertions(+), 333 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4fc741f2c..fe5a5c7b4 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -12,6 +12,9 @@ import json import logging import random +import uuid +import warnings +import weakref from typing import ( Any, AsyncGenerator, @@ -82,6 +85,26 @@ class _DefaultCallbackHandlerSentinel: _DEFAULT_AGENT_NAME = "Strands Agents" _DEFAULT_AGENT_ID = "default" +"""Global private store for agent cleanup - maps UUID to ToolRegistry. + +Why use weakref.finalize with a global store? + +MCP clients spawn background threads that must be properly cleaned up to prevent +thread leaks. In __del__, the agent's references to these threads may already be +deleted while the threads still exist, making cleanup impossible. The threads +then continue running, consuming resources and causing interpreter shutdown hangs. + +weakref.finalize cannot access 'self' or instance attributes because: +1. Finalizers run AFTER the object is garbage collected +2. Accessing 'self' would create a circular reference preventing GC +3. Instance attributes may be in an undefined state during finalization +4. __del__ methods can access 'self' but are unreliable (may never run) + +The global store enables safe cleanup without circular references +and reliable execution during interpreter shutdown. +""" +_AGENT_CLEANUP_STORE: dict[str, "ToolRegistry"] = {} + class Agent: """Core Agent interface. @@ -240,8 +263,8 @@ def __init__( - File paths (e.g., "/path/to/tool.py") - Imported Python modules (e.g., from strands_tools import current_time) - Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"}) - - Functions decorated with `@strands.tool` decorator - ToolProvider instances for managed tool collections + - Functions decorated with `@strands.tool` decorator. If provided, only these tools will be available. If None, all tools will be available. system_prompt: System prompt to guide model behavior. @@ -355,6 +378,10 @@ def __init__( self.hooks.add_hook(hook) self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + self._agent_uuid = str(uuid.uuid4()) + _AGENT_CLEANUP_STORE[self._agent_uuid] = self.tool_registry + self._finalizer = weakref.finalize(self, self._cleanup_on_finalize, self._agent_uuid, self.agent_id) + @property def tool(self) -> ToolCaller: """Call tool as a function. @@ -470,19 +497,20 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. - If you pass in a prompt, it will be used temporarily without adding it to the conversation history. - If you don't pass in a prompt, it will use only the existing conversation history to respond. + If you pass in a prompt, it will be used temporarily without adding it to the conversation history. + If you don't pass in a prompt, it will use only the existing conversation history to respond. - For smaller models, you may want to use the optional prompt to add additional instructions to explicitly - instruct the model to output the structured data. + For smaller models, you may want to use the optional prompt to add additional instructions to explicitly + instruct the model to output the structured data. Args: - output_model: The output model (a JSON schema written as a Pydantic BaseModel) - that the agent will use when responding. - prompt: The prompt to use for the agent (will not be added to conversation history). + output_model: The output model (a JSON schema written as a Pydantic BaseModel) + that the agent will use when responding. + prompt: The prompt to use for the agent (will not be added to conversation history). Raises: - ValueError: If no conversation history or prompt is provided. + ValueError: If no conversation history or prompt is provided. + - """ if self._interrupt_state.activated: raise RuntimeError("cannot call structured output during interrupt") @@ -538,87 +566,51 @@ def cleanup(self) -> None: to ensure proper resource cleanup. Note: This method uses a "belt and braces" approach with automatic cleanup - through __del__ as a fallback, but explicit cleanup is recommended. + through finalizers as a fallback, but explicit cleanup is recommended. """ - if getattr(self, "_cleanup_called", False): + if self._cleanup_called: return - run_async(self.cleanup_async) - - async def cleanup_async(self) -> None: - """Asynchronously clean up resources used by the agent. - - This method cleans up all tool providers that require explicit cleanup, - such as MCP clients. It should be called when the agent is no longer needed - to ensure proper resource cleanup. + self._cleanup_on_finalize(self._agent_uuid, self.agent_id) + self._cleanup_called = True - This method is idempotent and safe to call multiple times. + @staticmethod + def _cleanup_on_finalize(agent_uuid: str, agent_id: str) -> None: + """Static cleanup method called by weakref.finalize. + + WHY SYNCHRONOUS CLEANUP IS CRITICAL: + + weakref.finalize is safer than __del__ because: + 1. Runs AFTER garbage collection completes, not during (no GIL deadlocks) + 2. Cannot access 'self' so can't call methods that might block (no run_async deadlocks) + 3. Executes in a controlled environment where Python isn't in restricted GC state + 4. More reliable execution timing - __del__ can be delayed or skipped entirely + 5. No circular reference issues that can prevent __del__ from being called + 6. Uses global store to avoid copying complex objects at registration time + + SYNCHRONOUS CONSUMER MANAGEMENT PREVENTS: + - GC Deadlocks: run_async() creates ThreadPoolExecutor while GIL is held during GC + - Interpreter Shutdown Hangs: ThreadPoolExecutor creation fails during shutdown + - Finalizer Threading Issues: weakref.finalize runs in restricted threading environment + - Resource Leaks: Ensures MCP background threads are properly stopped + + The synchronous approach uses MCPClient.stop() which safely: + - Signals background thread via asyncio.Event + - Waits for thread completion with thread.join() + - Cleans up all resources without async/await """ - # Use getattr with False default: if _cleanup_called was deleted during garbage collection, - # we default to False (cleanup not called) to ensure cleanup still runs - if getattr(self, "_cleanup_called", False): + logger.debug("agent_id=<%s> | starting finalize cleanup", agent_id) + tool_registry = _AGENT_CLEANUP_STORE.pop(agent_uuid, None) + if not tool_registry: return - agent_id = getattr(self, "agent_id", None) - logger.debug("agent_id=<%s> | cleaning up agent resources", agent_id) - - tool_registry = getattr(self, "tool_registry", None) - if tool_registry: - await tool_registry.cleanup_async() - - self._cleanup_called = True - logger.debug("agent_id=<%s> | agent cleanup complete", agent_id) - - def __del__(self) -> None: - """Automatic cleanup when agent is garbage collected. - - This serves as a fallback cleanup mechanism, but explicit cleanup() is preferred. - """ - try: - self.cleanup() - except Exception as e: - agent_id = getattr(self, "agent_id", None) - logger.debug("agent_id=<%s>, error=<%s> | exception during __del__ cleanup", agent_id, e) + # Use synchronous cleanup to avoid run_async deadlocks during GC + tool_registry.cleanup() async def stream_async( self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> AsyncIterator[Any]: - """Process a natural language prompt and yield events as an async iterator. - - This method provides an asynchronous interface for streaming agent events with multiple input patterns: - - String input: Simple text input - - ContentBlock list: Multi-modal content blocks - - Message list: Complete messages with roles - - No input: Use existing conversation history - - Args: - prompt: User input in various formats: - - str: Simple text input - - list[ContentBlock]: Multi-modal content blocks - - list[Message]: Complete messages with roles - - None: Use existing conversation history - invocation_state: Additional parameters to pass through the event loop. - **kwargs: Additional parameters to pass to the event loop.[Deprecating] - - Yields: - An async iterator that yields events. Each event is a dictionary containing - information about the current state of processing, such as: - - - data: Text content being generated - - complete: Whether this is the final chunk - - current_tool_use: Information about tools being executed - - And other event data provided by the callback handler - - Raises: - Exception: Any exceptions from the agent invocation will be propagated to the caller. - - Example: - ```python - async for event in agent.stream_async("Analyze this data"): - if "data" in event: - yield event["data"] - ``` - """ + """Process a natural language prompt and yield events as an async iterator.""" self._resume_interrupt(prompt) merged_state = {} @@ -663,14 +655,7 @@ async def stream_async( raise def _resume_interrupt(self, prompt: AgentInput) -> None: - """Configure the interrupt state if resuming from an interrupt event. - - Args: - prompt: User responses if resuming from interrupt. - - Raises: - TypeError: If in interrupt state but user did not provide responses. - """ + """Configure the interrupt state if resuming from an interrupt event.""" if not self._interrupt_state.activated: return @@ -695,15 +680,7 @@ def _resume_interrupt(self, prompt: AgentInput) -> None: self._interrupt_state.interrupts[interrupt_id].response = interrupt_response async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: - """Execute the agent's event loop with the given message and parameters. - - Args: - messages: The input messages to add to the conversation. - invocation_state: Additional parameters to pass to the event loop. - - Yields: - Events from the event loop cycle. - """ + """Execute the agent's event loop with the given message and parameters.""" self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) try: @@ -735,15 +712,7 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: - """Execute the event loop cycle with retry logic for context window limits. - - This internal method handles the execution of the event loop cycle and implements - retry logic for handling context window overflow exceptions by reducing the - conversation context and retrying. - - Yields: - Events of the loop cycle. - """ + """Execute the event loop cycle with retry logic for context window limits.""" # Add `Agent` to invocation_state to keep backwards-compatibility invocation_state["agent"] = self @@ -805,20 +774,7 @@ def _record_tool_execution( tool_result: ToolResult, user_message_override: Optional[str], ) -> None: - """Record a tool execution in the message history. - - Creates a sequence of messages that represent the tool execution: - - 1. A user message describing the tool call - 2. An assistant message with the tool use - 3. A user message with the tool result - 4. An assistant message acknowledging the tool call - - Args: - tool: The tool call information. - tool_result: The result returned by the tool. - user_message_override: Optional custom message to include. - """ + """Record a tool execution in the message history.""" # Filter tool input parameters to only include those defined in tool spec filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) @@ -865,11 +821,7 @@ def _record_tool_execution( self._append_message(assistant_msg) def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: - """Starts a trace span for the agent. - - Args: - messages: The input messages. - """ + """Starts a trace span for the agent.""" model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None return self.tracer.start_agent_span( messages=messages, @@ -885,13 +837,7 @@ def _end_agent_trace_span( response: Optional[AgentResult] = None, error: Optional[Exception] = None, ) -> None: - """Ends a trace span for the agent. - - Args: - span: The span to end. - response: Response to record as a trace attribute. - error: Error to record as a trace attribute. - """ + """Ends a trace span for the agent.""" if self.trace_span: trace_attributes: dict[str, Any] = { "span": self.trace_span, @@ -905,15 +851,7 @@ def _end_agent_trace_span( self.tracer.end_agent_span(**trace_attributes) def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: - """Filter input parameters to only include those defined in the tool specification. - - Args: - tool_name: Name of the tool to get specification for - input_params: Original input parameters - - Returns: - Filtered parameters containing only those defined in tool spec - """ + """Filter input parameters to only include those defined in the tool specification.""" all_tools_config = self.tool_registry.get_all_tools_config() tool_spec = all_tools_config.get(tool_name) diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 401555368..2eac46492 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -27,9 +27,13 @@ async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: pass @abstractmethod - async def add_consumer(self, id: Any, **kwargs: Any) -> None: + def add_consumer(self, id: Any, **kwargs: Any) -> None: """Add a consumer to this tool provider. + This method is synchronous to avoid deadlocks during garbage collection. + When Agent finalizers run during GC, they need to clean up tool providers + without using run_async() which can deadlock due to GIL restrictions. + Args: id: Unique identifier for the consumer. **kwargs: Additional arguments for future compatibility. @@ -37,13 +41,17 @@ async def add_consumer(self, id: Any, **kwargs: Any) -> None: pass @abstractmethod - async def remove_consumer(self, id: Any, **kwargs: Any) -> None: + def remove_consumer(self, id: Any, **kwargs: Any) -> None: """Remove a consumer from this tool provider. + This method is synchronous to avoid deadlocks during garbage collection. + When Agent finalizers run during GC, they need to clean up tool providers + without using run_async() which can deadlock due to GIL restrictions. + + Provider may clean up resources when no consumers remain. + Args: id: Unique identifier for the consumer. **kwargs: Additional arguments for future compatibility. - - Provider may clean up resources when no consumers remain. """ pass diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 50ea43a3e..65695b609 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -33,7 +33,6 @@ from ...types.media import ImageFormat from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool -from .mcp_instrumentation import mcp_instrumentation from .mcp_types import MCPToolResult, MCPTransport logger = logging.getLogger(__name__) @@ -108,7 +107,7 @@ def __init__( self._tool_filters = tool_filters self._prefix = prefix - mcp_instrumentation() + # mcp_instrumentation() self._session_id = uuid.uuid4() self._log_debug_with_thread("initializing MCPClient connection") # Main thread blocks until future completesock @@ -210,13 +209,14 @@ async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: while True: logger.debug("page=<%d>, token=<%s> | fetching tools page", page_count, pagination_token) - paginated_tools = self.list_tools_sync(pagination_token, prefix=self._prefix) + # Use constructor defaults for prefix and filters in load_tools + paginated_tools = self.list_tools_sync( + pagination_token, prefix=self._prefix, tool_filters=self._tool_filters + ) - # Process each tool as we get it + # Tools are already filtered by list_tools_sync, so add them all for tool in paginated_tools: - # Apply filters - if self._should_include_tool(tool): - self._loaded_tools.append(tool) + self._loaded_tools.append(tool) logger.debug( "page=<%d>, page_tools=<%d>, total_filtered=<%d> | processed page", @@ -235,20 +235,27 @@ async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: return self._loaded_tools - async def add_consumer(self, id: Any, **kwargs: Any) -> None: - """Add a consumer to this tool provider.""" + def add_consumer(self, id: Any, **kwargs: Any) -> None: + """Add a consumer to this tool provider. + + Synchronous to prevent GC deadlocks when called from Agent finalizers. + """ self._consumers.add(id) logger.debug("added provider consumer, count=%d", len(self._consumers)) - async def remove_consumer(self, id: Any, **kwargs: Any) -> None: - """Remove a consumer from this tool provider.""" + def remove_consumer(self, id: Any, **kwargs: Any) -> None: + """Remove a consumer from this tool provider. + + Synchronous to prevent GC deadlocks when called from Agent finalizers. + Uses existing synchronous stop() method for safe cleanup. + """ self._consumers.discard(id) logger.debug("removed provider consumer, count=%d", len(self._consumers)) if not self._consumers and self._tool_provider_started: logger.debug("no consumers remaining, cleaning up") try: - self.stop(None, None, None) + self.stop(None, None, None) # Existing sync method - safe for finalizers self._tool_provider_started = False self._loaded_tools = None except Exception as e: @@ -312,7 +319,10 @@ async def _set_close_event() -> None: self._consumers = set() def list_tools_sync( - self, pagination_token: Optional[str] = None, prefix: Optional[str] = None + self, + pagination_token: Optional[str] = None, + prefix: Optional[str] = None, + tool_filters: Optional[ToolFilters] = None, ) -> PaginatedList[MCPAgentTool]: """Synchronously retrieves the list of available tools from the MCP server. @@ -321,7 +331,10 @@ def list_tools_sync( Args: pagination_token: Optional token for pagination - prefix: Optional prefix to apply to tool names + prefix: Optional prefix to apply to tool names. If None, uses constructor default. + If explicitly provided (including empty string), overrides constructor default. + tool_filters: Optional filters to apply to tools. If None, uses constructor default. + If explicitly provided (including empty dict), overrides constructor default. Returns: List[AgentTool]: A list of available tools adapted to the AgentTool interface @@ -330,6 +343,9 @@ def list_tools_sync( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + effective_prefix = self._prefix if prefix is None else prefix + effective_filters = self._tool_filters if tool_filters is None else tool_filters + async def _list_tools_async() -> ListToolsResult: return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token) @@ -338,13 +354,17 @@ async def _list_tools_async() -> ListToolsResult: mcp_tools = [] for tool in list_tools_response.tools: - if prefix: - prefixed_name = f"{prefix}_{tool.name}" + # Apply prefix if specified + if effective_prefix: + prefixed_name = f"{effective_prefix}_{tool.name}" mcp_tool = MCPAgentTool(tool, self, name_override=prefixed_name) logger.debug("tool_rename=<%s->%s> | renamed tool", tool.name, prefixed_name) else: mcp_tool = MCPAgentTool(tool, self) - mcp_tools.append(mcp_tool) + + # Apply filters if specified + if self._should_include_tool_with_filters(mcp_tool, effective_filters): + mcp_tools.append(mcp_tool) self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) @@ -669,21 +689,24 @@ def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures. return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) def _should_include_tool(self, tool: MCPAgentTool) -> bool: - """Check if a tool should be included based on allowed/rejected filters.""" - if not self._tool_filters: + """Check if a tool should be included based on constructor filters.""" + return self._should_include_tool_with_filters(tool, self._tool_filters) + + def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: Optional[ToolFilters]) -> bool: + """Check if a tool should be included based on provided filters.""" + if not filters: return True # Apply allowed filter - if "allowed" in self._tool_filters: - if not self._matches_patterns(tool, self._tool_filters["allowed"]): + if "allowed" in filters: + if not self._matches_patterns(tool, filters["allowed"]): return False # Apply rejected filter - if "rejected" in self._tool_filters: - if self._matches_patterns(tool, self._tool_filters["rejected"]): + if "rejected" in filters: + if self._matches_patterns(tool, filters["rejected"]): return False - print(f"Returning true for {tool.mcp_tool.name} {tool.tool_name}") return True def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool: @@ -696,7 +719,6 @@ def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPatter if pattern.match(tool.mcp_tool.name): return True elif isinstance(pattern, str): - print(f"checking {pattern} against {tool.mcp_tool.name}") if pattern == tool.mcp_tool.name: return True return False diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index b9f861f60..162c84c63 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -128,12 +128,11 @@ def add_tool(tool: Any) -> None: elif isinstance(tool, ToolProvider): self._tool_providers.append(tool) - async def get_tools_and_register_consumer() -> Sequence[AgentTool]: - provider_tools = await tool.load_tools() - await tool.add_consumer(self._registry_id) - return provider_tools + async def get_tools() -> Sequence[AgentTool]: + return await tool.load_tools() - provider_tools = run_async(get_tools_and_register_consumer) + provider_tools = run_async(get_tools) + tool.add_consumer(self._registry_id) # Now sync for provider_tool in provider_tools: self.register_tool(provider_tool) @@ -661,11 +660,20 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: return tools - async def cleanup_async(self, **kwargs: Any) -> None: - """Clean up all tool providers in this registry.""" + def cleanup(self, **kwargs: Any) -> None: + """Synchronously clean up all tool providers in this registry. + + This method is safe to call from Agent finalizers during garbage collection + because it avoids run_async() which can deadlock when the GIL is held. + + The synchronous approach prevents: + 1. GC deadlocks - run_async() creates ThreadPoolExecutor during GC + 2. Interpreter shutdown hangs - ThreadPoolExecutor creation fails during shutdown + 3. Finalizer threading issues - weakref.finalize runs in restricted environment + """ for provider in self._tool_providers: try: - await provider.remove_consumer(self._registry_id) + provider.remove_consumer(self._registry_id) # Now sync logger.debug("provider=<%s> | removed provider consumer", type(provider).__name__) except Exception as e: logger.warning( diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 49965caf1..46c48b004 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -889,126 +889,6 @@ def test_agent_tool_names(tools, agent): assert actual == expected -def test_agent_cleanup(): - """Test that agent cleanup method works correctly.""" - # Create a fresh agent to avoid fixture interference - agent = Agent() - - with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: - agent.cleanup() - - # Verify run_async was called once (for cleanup_async) - mock_run_async.assert_called_once() - # Get the function that was passed to run_async and verify it's cleanup_async - called_func = mock_run_async.call_args[0][0] - assert called_func == agent.cleanup_async - - -@pytest.mark.asyncio -async def test_agent_cleanup_async(agent): - """Test that agent cleanup_async method works correctly.""" - with unittest.mock.patch.object(agent.tool_registry, "cleanup_async") as mock_registry_cleanup: - await agent.cleanup_async() - - # Verify registry cleanup was called - mock_registry_cleanup.assert_called_once() - # Verify cleanup was marked as called - assert agent._cleanup_called is True - - -@pytest.mark.asyncio -async def test_agent_cleanup_async_idempotent(agent): - """Test that calling cleanup_async multiple times is safe.""" - with unittest.mock.patch.object(agent.tool_registry, "cleanup_async") as mock_registry_cleanup: - # Call cleanup_async twice - await agent.cleanup_async() - await agent.cleanup_async() - - # Verify registry cleanup was only called once due to idempotency - mock_registry_cleanup.assert_called_once() - - -def test_agent__del__(agent): - """Test that agent destructor calls cleanup.""" - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: - agent.__del__() - mock_cleanup.assert_called_once() - - -def test_agent__del__handles_cleanup_exception(agent): - """Test that agent destructor handles cleanup exceptions.""" - with unittest.mock.patch.object(agent, "cleanup", side_effect=Exception("Cleanup failed")): - # Should not raise exception - agent.__del__() - - -def test_agent_cleanup_idempotent(agent): - """Test that calling cleanup multiple times is safe.""" - with unittest.mock.patch.object(agent.tool_registry, "cleanup_async") as mock_registry_cleanup: - # Call cleanup twice - agent.cleanup() - agent.cleanup() - - # Verify registry cleanup was only called once due to idempotency - mock_registry_cleanup.assert_called_once() - - -def test_agent_cleanup_early_return_avoids_thread_spawn(agent): - """Test that cleanup returns early when already called, avoiding thread spawn cost.""" - # Mark cleanup as already called - agent._cleanup_called = True - - with unittest.mock.patch("strands.agent.agent.run_async") as mock_run_async: - agent.cleanup() - - # Verify run_async was not called since cleanup already happened - mock_run_async.assert_not_called() - - -def test_agent__del__emits_warning_for_automatic_cleanup(): - """Test that __del__ emits warning when cleanup wasn't called manually.""" - # Create a fresh agent for this test to avoid fixture lifecycle issues - agent = Agent() - - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: - agent.__del__() - - # Verify cleanup was called (warning logic is in agent.__del__ implementation) - mock_cleanup.assert_called_once() - - -def test_agent__del__no_warning_after_manual_cleanup(): - """Test that __del__ doesn't emit warning if cleanup was called manually.""" - # Create a fresh agent for this test - from strands import Agent - - agent = Agent() - - # Call cleanup manually first - with unittest.mock.patch.object(agent, "cleanup_async"): - agent.cleanup() - - with unittest.mock.patch("strands.agent.agent.logger") as mock_logger: - agent.__del__() - - # Verify no warning was logged - mock_logger.warning.assert_not_called() - - -def test_agent__del__no_warning_when_no_tool_providers(): - """Test that __del__ doesn't emit warning when there are no tool providers.""" - # Create a fresh agent for this test - from strands import Agent - - agent = Agent() - - with unittest.mock.patch.object(agent, "cleanup") as mock_cleanup: - agent.__del__() - - # Cleanup is always called in __del__, regardless of providers - mock_cleanup.assert_called_once() - - def test_agent_init_with_no_model_or_model_id(): agent = Agent() assert agent.model is not None diff --git a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py index 42c617eef..9cb90167d 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tool_provider.py +++ b/tests/strands/tools/mcp/test_mcp_client_tool_provider.py @@ -153,8 +153,8 @@ async def test_load_tools_handles_pagination(mock_transport): # Should have called list_tools_sync twice assert mock_list_tools.call_count == 2 # First call with no token, second call with "page2" token - mock_list_tools.assert_any_call(None, prefix=None) - mock_list_tools.assert_any_call("page2", prefix=None) + mock_list_tools.assert_any_call(None, prefix=None, tool_filters=None) + mock_list_tools.assert_any_call("page2", prefix=None, tool_filters=None) assert len(tools) == 2 assert tools[0] is tool1 @@ -165,14 +165,14 @@ async def test_load_tools_handles_pagination(mock_transport): async def test_allowed_filter_string_match(mock_transport): """Test allowed filter with string matching.""" tool1 = create_mock_tool("allowed_tool") - tool2 = create_mock_tool("rejected_tool") filters: ToolFilters = {"allowed": ["allowed_tool"]} client = MCPClient(mock_transport, tool_filters=filters) client._tool_provider_started = True with patch.object(client, "list_tools_sync") as mock_list_tools: - mock_list_tools.return_value = PaginatedList([tool1, tool2]) + # Mock list_tools_sync to return filtered results (simulating the filtering) + mock_list_tools.return_value = PaginatedList([tool1]) # Only allowed tool tools = await client.load_tools() @@ -184,14 +184,14 @@ async def test_allowed_filter_string_match(mock_transport): async def test_allowed_filter_regex_match(mock_transport): """Test allowed filter with regex matching.""" tool1 = create_mock_tool("echo_tool") - tool2 = create_mock_tool("other_tool") filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} client = MCPClient(mock_transport, tool_filters=filters) client._tool_provider_started = True with patch.object(client, "list_tools_sync") as mock_list_tools: - mock_list_tools.return_value = PaginatedList([tool1, tool2]) + # Mock list_tools_sync to return filtered results + mock_list_tools.return_value = PaginatedList([tool1]) # Only echo tool tools = await client.load_tools() @@ -203,7 +203,6 @@ async def test_allowed_filter_regex_match(mock_transport): async def test_allowed_filter_callable_match(mock_transport): """Test allowed filter with callable matching.""" tool1 = create_mock_tool("short") - tool2 = create_mock_tool("very_long_tool_name") def short_names_only(tool) -> bool: return len(tool.tool_name) <= 10 @@ -213,7 +212,8 @@ def short_names_only(tool) -> bool: client._tool_provider_started = True with patch.object(client, "list_tools_sync") as mock_list_tools: - mock_list_tools.return_value = PaginatedList([tool1, tool2]) + # Mock list_tools_sync to return filtered results + mock_list_tools.return_value = PaginatedList([tool1]) # Only short tool tools = await client.load_tools() @@ -225,14 +225,14 @@ def short_names_only(tool) -> bool: async def test_rejected_filter_string_match(mock_transport): """Test rejected filter with string matching.""" tool1 = create_mock_tool("good_tool") - tool2 = create_mock_tool("bad_tool") filters: ToolFilters = {"rejected": ["bad_tool"]} client = MCPClient(mock_transport, tool_filters=filters) client._tool_provider_started = True with patch.object(client, "list_tools_sync") as mock_list_tools: - mock_list_tools.return_value = PaginatedList([tool1, tool2]) + # Mock list_tools_sync to return filtered results + mock_list_tools.return_value = PaginatedList([tool1]) # Only good tool tools = await client.load_tools() @@ -283,34 +283,31 @@ async def test_prefix_renames_tools(mock_transport): assert result[0] is mock_agent_tool -@pytest.mark.asyncio -async def test_add_consumer(mock_transport): +def test_add_consumer(mock_transport): """Test adding a provider consumer.""" client = MCPClient(mock_transport) - await client.add_consumer("consumer1") + client.add_consumer("consumer1") assert "consumer1" in client._consumers assert len(client._consumers) == 1 -@pytest.mark.asyncio -async def test_remove_consumer_without_cleanup(mock_transport): +def test_remove_consumer_without_cleanup(mock_transport): """Test removing a provider consumer without triggering cleanup.""" client = MCPClient(mock_transport) client._consumers.add("consumer1") client._consumers.add("consumer2") client._tool_provider_started = True - await client.remove_consumer("consumer1") + client.remove_consumer("consumer1") assert "consumer1" not in client._consumers assert "consumer2" in client._consumers assert client._tool_provider_started is True # Should not cleanup yet -@pytest.mark.asyncio -async def test_remove_consumer_with_cleanup(mock_transport): +def test_remove_consumer_with_cleanup(mock_transport): """Test removing the last provider consumer triggers cleanup.""" client = MCPClient(mock_transport) client._consumers.add("consumer1") @@ -318,7 +315,7 @@ async def test_remove_consumer_with_cleanup(mock_transport): client._loaded_tools = [MagicMock()] with patch.object(client, "stop") as mock_stop: - await client.remove_consumer("consumer1") + client.remove_consumer("consumer1") assert len(client._consumers) == 0 assert client._tool_provider_started is False @@ -326,8 +323,7 @@ async def test_remove_consumer_with_cleanup(mock_transport): mock_stop.assert_called_once_with(None, None, None) -@pytest.mark.asyncio -async def test_remove_consumer_cleanup_failure(mock_transport): +def test_remove_consumer_cleanup_failure(mock_transport): """Test that remove_consumer raises ToolProviderException when cleanup fails.""" client = MCPClient(mock_transport) client._consumers.add("consumer1") @@ -337,7 +333,7 @@ async def test_remove_consumer_cleanup_failure(mock_transport): mock_stop.side_effect = Exception("Cleanup failed") with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Cleanup failed"): - await client.remove_consumer("consumer1") + client.remove_consumer("consumer1") def test_mcp_client_reuse_across_multiple_agents(mock_transport): @@ -376,3 +372,455 @@ def test_mcp_client_reuse_across_multiple_agents(mock_transport): # Final cleanup when last agent is removed agent_2.cleanup() mock_stop.assert_called_once() # Now it should stop + + +def test_list_tools_sync_prefix_override_constructor_default(mock_transport): + """Test that list_tools_sync can override constructor prefix.""" + # Create a mock MCP tool + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_tool" + + # Client with constructor prefix + client = MCPClient(mock_transport, prefix="constructor") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "override_original_tool" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call with override prefix + result = client.list_tools_sync(prefix="override") + + # Should use override prefix, not constructor prefix + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client, name_override="override_original_tool") + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_list_tools_sync_prefix_override_with_empty_string(mock_transport): + """Test that list_tools_sync can override constructor prefix with empty string.""" + # Create a mock MCP tool + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_tool" + + # Client with constructor prefix + client = MCPClient(mock_transport, prefix="constructor") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "original_tool" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call with empty string prefix (should override constructor default) + result = client.list_tools_sync(prefix="") + + # Should use no prefix (empty string overrides constructor) + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client) + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_list_tools_sync_prefix_uses_constructor_default_when_none(mock_transport): + """Test that list_tools_sync uses constructor prefix when None is passed.""" + # Create a mock MCP tool + mock_mcp_tool = MagicMock() + mock_mcp_tool.name = "original_tool" + + # Client with constructor prefix + client = MCPClient(mock_transport, prefix="constructor") + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [mock_mcp_tool] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation + mock_agent_tool = MagicMock(spec=MCPAgentTool) + mock_agent_tool.tool_name = "constructor_original_tool" + mock_agent_tool_class.return_value = mock_agent_tool + + # Call with None prefix (should use constructor default) + result = client.list_tools_sync(prefix=None) + + # Should use constructor prefix + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client, name_override="constructor_original_tool") + + assert len(result) == 1 + assert result[0] is mock_agent_tool + + +def test_list_tools_sync_tool_filters_override_constructor_default(mock_transport): + """Test that list_tools_sync can override constructor tool_filters.""" + # Create mock tools + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + # Client with constructor filters that would allow both + constructor_filters: ToolFilters = {"allowed": ["allowed_tool", "rejected_tool"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters) + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [MagicMock(name="allowed_tool"), MagicMock(name="rejected_tool")] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation to return our test tools + mock_agent_tool_class.side_effect = [tool1, tool2] + + # Override filters to only allow one tool + override_filters: ToolFilters = {"allowed": ["allowed_tool"]} + result = client.list_tools_sync(tool_filters=override_filters) + + # Should only include the allowed tool based on override filters + assert len(result) == 1 + assert result[0] is tool1 + + +def test_list_tools_sync_tool_filters_override_with_empty_dict(mock_transport): + """Test that list_tools_sync can override constructor filters with empty dict.""" + # Create mock tools + tool1 = create_mock_tool("tool1") + tool2 = create_mock_tool("tool2") + + # Client with constructor filters that would reject tools + constructor_filters: ToolFilters = {"rejected": ["tool1", "tool2"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters) + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [MagicMock(name="tool1"), MagicMock(name="tool2")] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation to return our test tools + mock_agent_tool_class.side_effect = [tool1, tool2] + + # Override with empty filters (should allow all tools) + result = client.list_tools_sync(tool_filters={}) + + # Should include both tools since empty filters allow everything + assert len(result) == 2 + assert result[0] is tool1 + assert result[1] is tool2 + + +def test_list_tools_sync_tool_filters_uses_constructor_default_when_none(mock_transport): + """Test that list_tools_sync uses constructor filters when None is passed.""" + # Create mock tools + tool1 = create_mock_tool("allowed_tool") + tool2 = create_mock_tool("rejected_tool") + + # Client with constructor filters + constructor_filters: ToolFilters = {"allowed": ["allowed_tool"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters) + client._tool_provider_started = True + + # Mock the session active state + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + client._background_thread = mock_thread + + with ( + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + # Mock the MCP server response + mock_list_tools_result = MagicMock() + mock_list_tools_result.tools = [MagicMock(name="allowed_tool"), MagicMock(name="rejected_tool")] + mock_list_tools_result.nextCursor = None + + mock_future = MagicMock() + mock_future.result.return_value = mock_list_tools_result + mock_invoke.return_value = mock_future + + # Mock MCPAgentTool creation to return our test tools + mock_agent_tool_class.side_effect = [tool1, tool2] + + # Call with None filters (should use constructor default) + result = client.list_tools_sync(tool_filters=None) + + # Should only include allowed tool based on constructor filters + assert len(result) == 1 + assert result[0] is tool1 + + +def test_list_tools_sync_combined_prefix_and_filter_overrides(mock_transport): + """Test that list_tools_sync can override both prefix and filters simultaneously.""" + # Client with constructor defaults + constructor_filters: ToolFilters = {"allowed": ["echo_tool", "other_tool"]} + client = MCPClient(mock_transport, tool_filters=constructor_filters, prefix="constructor") + + # Create mock tools + mock_echo_tool = MagicMock() + mock_echo_tool.name = "echo_tool" + mock_other_tool = MagicMock() + mock_other_tool.name = "other_tool" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_echo_tool, mock_other_tool] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_echo_tool + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_other_tool + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Override both prefix and filters + override_filters: ToolFilters = {"allowed": ["echo_tool"]} + result = client.list_tools_sync(prefix="override", tool_filters=override_filters) + + # Verify prefix override: should use "override" not "constructor" + calls = mock_agent_tool_class.call_args_list + assert len(calls) == 2 + + # First tool should have override prefix + args1, kwargs1 = calls[0] + assert args1 == (mock_echo_tool, client) + assert kwargs1 == {"name_override": "override_echo_tool"} + + # Second tool should have override prefix + args2, kwargs2 = calls[1] + assert args2 == (mock_other_tool, client) + assert kwargs2 == {"name_override": "override_other_tool"} + + # Verify filter override: should only include echo_tool based on override filters + assert len(result) == 1 + assert result[0] is mock_agent_tool1 + + +def test_list_tools_sync_direct_usage_without_constructor_defaults(mock_transport): + """Test direct usage of list_tools_sync without constructor defaults.""" + # Client without constructor defaults + client = MCPClient(mock_transport) + + # Create mock tools + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_tool1, mock_tool2] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_tool1 + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_tool2 + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Direct usage with explicit parameters + filters: ToolFilters = {"allowed": ["tool1"]} + result = client.list_tools_sync(prefix="direct", tool_filters=filters) + + # Verify prefix is applied + calls = mock_agent_tool_class.call_args_list + assert len(calls) == 2 + + # Should create tools with direct prefix + args1, kwargs1 = calls[0] + assert args1 == (mock_tool1, client) + assert kwargs1 == {"name_override": "direct_tool1"} + + args2, kwargs2 = calls[1] + assert args2 == (mock_tool2, client) + assert kwargs2 == {"name_override": "direct_tool2"} + + # Verify filtering: should only include tool1 + assert len(result) == 1 + assert result[0] is mock_agent_tool1 + + +def test_list_tools_sync_regex_filter_override(mock_transport): + """Test list_tools_sync with regex filter override.""" + # Client without constructor filters + client = MCPClient(mock_transport) + + # Create mock tools + mock_echo_tool = MagicMock() + mock_echo_tool.name = "echo_command" + mock_list_tool = MagicMock() + mock_list_tool.name = "list_files" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_echo_tool, mock_list_tool] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_echo_tool + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_list_tool + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Use regex filter to match only echo tools + regex_filters: ToolFilters = {"allowed": [re.compile(r"echo_.*")]} + result = client.list_tools_sync(tool_filters=regex_filters) + + # Should create both tools + assert mock_agent_tool_class.call_count == 2 + + # Should only include echo tool (regex matches "echo_command") + assert len(result) == 1 + assert result[0] is mock_agent_tool1 + + +def test_list_tools_sync_callable_filter_override(mock_transport): + """Test list_tools_sync with callable filter override.""" + # Client without constructor filters + client = MCPClient(mock_transport) + + # Create mock tools + mock_short_tool = MagicMock() + mock_short_tool.name = "short" + mock_long_tool = MagicMock() + mock_long_tool.name = "very_long_tool_name" + + # Mock the MCP response + mock_result = MagicMock() + mock_result.tools = [mock_short_tool, mock_long_tool] + mock_result.nextCursor = None + + with ( + patch.object(client, "_is_session_active", return_value=True), + patch.object(client, "_invoke_on_background_thread") as mock_invoke, + patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, + ): + mock_future = MagicMock() + mock_future.result.return_value = mock_result + mock_invoke.return_value = mock_future + + # Create mock agent tools + mock_agent_tool1 = MagicMock() + mock_agent_tool1.mcp_tool = mock_short_tool + mock_agent_tool2 = MagicMock() + mock_agent_tool2.mcp_tool = mock_long_tool + mock_agent_tool_class.side_effect = [mock_agent_tool1, mock_agent_tool2] + + # Use callable filter for short names only + def short_names_only(tool) -> bool: + return len(tool.mcp_tool.name) <= 10 + + callable_filters: ToolFilters = {"allowed": [short_names_only]} + result = client.list_tools_sync(tool_filters=callable_filters) + + # Should create both tools + assert mock_agent_tool_class.call_count == 2 + + # Should only include short tool (name length <= 10) + assert len(result) == 1 + assert result[0] is mock_agent_tool1 diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index 7fdb4e07c..6fb2d6463 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -24,16 +24,16 @@ def __init__(self, tools=None, cleanup_error=None): async def load_tools(self): return self._tools - async def cleanup(self): + def cleanup(self): self.cleanup_called = True if self._cleanup_error: raise self._cleanup_error - async def add_consumer(self, consumer_id): + def add_consumer(self, consumer_id): self.add_consumer_called = True self.add_consumer_id = consumer_id - async def remove_consumer(self, consumer_id): + def remove_consumer(self, consumer_id): self.remove_consumer_called = True self.remove_consumer_id = consumer_id @@ -228,17 +228,17 @@ async def load_tools(self): self.load_tools_called = True return [mock_tool] - async def add_consumer(self, consumer_id): + def add_consumer(self, consumer_id): self.add_consumer_called = True self.add_consumer_id = consumer_id - async def remove_consumer(self, consumer_id): + def remove_consumer(self, consumer_id): pass provider = TestProvider() registry = ToolRegistry() - # Process the provider - this should call both methods in same async context + # Process the provider - this should call both methods tool_names = registry.process_tools([provider]) # Verify both methods were called @@ -250,8 +250,7 @@ async def remove_consumer(self, consumer_id): assert "test_tool" in tool_names assert provider in registry._tool_providers - @pytest.mark.asyncio - async def test_registry_cleanup(self): + def test_registry_cleanup(self): """Test that registry cleanup calls remove_consumer on all providers.""" provider1 = MockToolProvider() provider2 = MockToolProvider() @@ -259,14 +258,13 @@ async def test_registry_cleanup(self): registry = ToolRegistry() registry._tool_providers = [provider1, provider2] - await registry.cleanup_async() + registry.cleanup() # Verify both providers had remove_consumer called assert provider1.remove_consumer_called assert provider2.remove_consumer_called - @pytest.mark.asyncio - async def test_registry_cleanup_with_provider_consumer_removal(self): + def test_registry_cleanup_with_provider_consumer_removal(self): """Test that cleanup removes provider consumers correctly.""" class TestProvider(ToolProvider): @@ -277,10 +275,10 @@ def __init__(self): async def load_tools(self): return [] - async def add_consumer(self, consumer_id): + def add_consumer(self, consumer_id): pass - async def remove_consumer(self, consumer_id): + def remove_consumer(self, consumer_id): self.remove_consumer_called = True self.remove_consumer_id = consumer_id @@ -289,7 +287,7 @@ async def remove_consumer(self, consumer_id): registry._tool_providers = [provider] # Call cleanup - await registry.cleanup_async() + registry.cleanup() # Verify remove_consumer was called with correct ID assert provider.remove_consumer_called From 1a99e13892111330633e55957c589b8a8cdd469f Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 21 Oct 2025 13:48:00 -0400 Subject: [PATCH 19/35] simplify run_async --- src/strands/_async.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/strands/_async.py b/src/strands/_async.py index 976487c37..8ea9098fe 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -20,12 +20,6 @@ def run_async(async_func: Callable[[], Awaitable[T]]) -> T: The result of the async function """ - async def execute_async() -> T: - return await async_func() - - def execute() -> T: - return asyncio.run(execute_async()) - with ThreadPoolExecutor() as executor: - future = executor.submit(execute) + future = executor.submit(asyncio.run, async_func()) return future.result() From ccebe72025f50673fb2323ad473a561886d6a449 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 21 Oct 2025 13:49:32 -0400 Subject: [PATCH 20/35] linting --- src/strands/_async.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/strands/_async.py b/src/strands/_async.py index 8ea9098fe..b24010fb7 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -19,7 +19,6 @@ def run_async(async_func: Callable[[], Awaitable[T]]) -> T: Returns: The result of the async function """ - with ThreadPoolExecutor() as executor: future = executor.submit(asyncio.run, async_func()) return future.result() From a0d5d700d4f8a7e919397418af96cd1543dd2b19 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 21 Oct 2025 14:47:33 -0400 Subject: [PATCH 21/35] clean --- src/strands/_async.py | 9 +- src/strands/agent/agent.py | 125 +++++++++++++++--- .../experimental/tools/tool_provider.py | 8 -- src/strands/tools/mcp/mcp_client.py | 18 +-- src/strands/tools/registry.py | 27 ++-- .../tools/test_registry_tool_provider.py | 34 +++++ 6 files changed, 168 insertions(+), 53 deletions(-) diff --git a/src/strands/_async.py b/src/strands/_async.py index b24010fb7..976487c37 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -19,6 +19,13 @@ def run_async(async_func: Callable[[], Awaitable[T]]) -> T: Returns: The result of the async function """ + + async def execute_async() -> T: + return await async_func() + + def execute() -> T: + return asyncio.run(execute_async()) + with ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, async_func()) + future = executor.submit(execute) return future.result() diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index fe5a5c7b4..3d405ed0e 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -497,19 +497,19 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. - If you pass in a prompt, it will be used temporarily without adding it to the conversation history. - If you don't pass in a prompt, it will use only the existing conversation history to respond. + If you pass in a prompt, it will be used temporarily without adding it to the conversation history. + If you don't pass in a prompt, it will use only the existing conversation history to respond. - For smaller models, you may want to use the optional prompt to add additional instructions to explicitly - instruct the model to output the structured data. + For smaller models, you may want to use the optional prompt to add additional instructions to explicitly + instruct the model to output the structured data. Args: - output_model: The output model (a JSON schema written as a Pydantic BaseModel) - that the agent will use when responding. - prompt: The prompt to use for the agent (will not be added to conversation history). + output_model: The output model (a JSON schema written as a Pydantic BaseModel) + that the agent will use when responding. + prompt: The prompt to use for the agent (will not be added to conversation history). Raises: - ValueError: If no conversation history or prompt is provided. + ValueError: If no conversation history or prompt is provided. - """ if self._interrupt_state.activated: @@ -578,9 +578,7 @@ def cleanup(self) -> None: def _cleanup_on_finalize(agent_uuid: str, agent_id: str) -> None: """Static cleanup method called by weakref.finalize. - WHY SYNCHRONOUS CLEANUP IS CRITICAL: - - weakref.finalize is safer than __del__ because: + weakref.finalize is used over __del__ because: 1. Runs AFTER garbage collection completes, not during (no GIL deadlocks) 2. Cannot access 'self' so can't call methods that might block (no run_async deadlocks) 3. Executes in a controlled environment where Python isn't in restricted GC state @@ -610,7 +608,42 @@ def _cleanup_on_finalize(agent_uuid: str, agent_id: str) -> None: async def stream_async( self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> AsyncIterator[Any]: - """Process a natural language prompt and yield events as an async iterator.""" + """Process a natural language prompt and yield events as an async iterator. + + This method provides an asynchronous interface for streaming agent events with multiple input patterns: + - String input: Simple text input + - ContentBlock list: Multi-modal content blocks + - Message list: Complete messages with roles + - No input: Use existing conversation history + + Args: + prompt: User input in various formats: + - str: Simple text input + - list[ContentBlock]: Multi-modal content blocks + - list[Message]: Complete messages with roles + - None: Use existing conversation history + invocation_state: Additional parameters to pass through the event loop. + **kwargs: Additional parameters to pass to the event loop.[Deprecating] + + Yields: + An async iterator that yields events. Each event is a dictionary containing + information about the current state of processing, such as: + + - data: Text content being generated + - complete: Whether this is the final chunk + - current_tool_use: Information about tools being executed + - And other event data provided by the callback handler + + Raises: + Exception: Any exceptions from the agent invocation will be propagated to the caller. + + Example: + ```python + async for event in agent.stream_async("Analyze this data"): + if "data" in event: + yield event["data"] + ``` + """ self._resume_interrupt(prompt) merged_state = {} @@ -655,7 +688,14 @@ async def stream_async( raise def _resume_interrupt(self, prompt: AgentInput) -> None: - """Configure the interrupt state if resuming from an interrupt event.""" + """Configure the interrupt state if resuming from an interrupt event. + + Args: + prompt: User responses if resuming from interrupt. + + Raises: + TypeError: If in interrupt state but user did not provide responses. + """ if not self._interrupt_state.activated: return @@ -680,7 +720,15 @@ def _resume_interrupt(self, prompt: AgentInput) -> None: self._interrupt_state.interrupts[interrupt_id].response = interrupt_response async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: - """Execute the agent's event loop with the given message and parameters.""" + """Execute the agent's event loop with the given message and parameters. + + Args: + messages: The input messages to add to the conversation. + invocation_state: Additional parameters to pass to the event loop. + + Yields: + Events from the event loop cycle. + """ self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) try: @@ -712,7 +760,15 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: - """Execute the event loop cycle with retry logic for context window limits.""" + """Execute the event loop cycle with retry logic for context window limits. + + This internal method handles the execution of the event loop cycle and implements + retry logic for handling context window overflow exceptions by reducing the + conversation context and retrying. + + Yields: + Events of the loop cycle. + """ # Add `Agent` to invocation_state to keep backwards-compatibility invocation_state["agent"] = self @@ -774,7 +830,20 @@ def _record_tool_execution( tool_result: ToolResult, user_message_override: Optional[str], ) -> None: - """Record a tool execution in the message history.""" + """Record a tool execution in the message history. + + Creates a sequence of messages that represent the tool execution: + + 1. A user message describing the tool call + 2. An assistant message with the tool use + 3. A user message with the tool result + 4. An assistant message acknowledging the tool call + + Args: + tool: The tool call information. + tool_result: The result returned by the tool. + user_message_override: Optional custom message to include. + """ # Filter tool input parameters to only include those defined in tool spec filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) @@ -821,7 +890,11 @@ def _record_tool_execution( self._append_message(assistant_msg) def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: - """Starts a trace span for the agent.""" + """Starts a trace span for the agent. + + Args: + messages: The input messages. + """ model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None return self.tracer.start_agent_span( messages=messages, @@ -837,7 +910,13 @@ def _end_agent_trace_span( response: Optional[AgentResult] = None, error: Optional[Exception] = None, ) -> None: - """Ends a trace span for the agent.""" + """Ends a trace span for the agent. + + Args: + span: The span to end. + response: Response to record as a trace attribute. + error: Error to record as a trace attribute. + """ if self.trace_span: trace_attributes: dict[str, Any] = { "span": self.trace_span, @@ -851,7 +930,15 @@ def _end_agent_trace_span( self.tracer.end_agent_span(**trace_attributes) def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: - """Filter input parameters to only include those defined in the tool specification.""" + """Filter input parameters to only include those defined in the tool specification. + + Args: + tool_name: Name of the tool to get specification for + input_params: Original input parameters + + Returns: + Filtered parameters containing only those defined in tool spec + """ all_tools_config = self.tool_registry.get_all_tools_config() tool_spec = all_tools_config.get(tool_name) diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 2eac46492..03f4f7aa3 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -30,10 +30,6 @@ async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: def add_consumer(self, id: Any, **kwargs: Any) -> None: """Add a consumer to this tool provider. - This method is synchronous to avoid deadlocks during garbage collection. - When Agent finalizers run during GC, they need to clean up tool providers - without using run_async() which can deadlock due to GIL restrictions. - Args: id: Unique identifier for the consumer. **kwargs: Additional arguments for future compatibility. @@ -44,10 +40,6 @@ def add_consumer(self, id: Any, **kwargs: Any) -> None: def remove_consumer(self, id: Any, **kwargs: Any) -> None: """Remove a consumer from this tool provider. - This method is synchronous to avoid deadlocks during garbage collection. - When Agent finalizers run during GC, they need to clean up tool providers - without using run_async() which can deadlock due to GIL restrictions. - Provider may clean up resources when no consumers remain. Args: diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 65695b609..67ded10bb 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -44,7 +44,7 @@ class _ToolFilterCallback(Protocol): def __call__(self, tool: AgentTool, **kwargs: Any) -> bool: ... -_ToolFilterPattern = str | Pattern[str] | _ToolFilterCallback +_ToolMatcher = str | Pattern[str] | _ToolFilterCallback class ToolFilters(TypedDict, total=False): @@ -55,8 +55,8 @@ class ToolFilters(TypedDict, total=False): 2. Tools matching 'rejected' patterns are then excluded """ - allowed: list[_ToolFilterPattern] - rejected: list[_ToolFilterPattern] + allowed: list[_ToolMatcher] + rejected: list[_ToolMatcher] MIME_TO_FORMAT: Dict[str, ImageFormat] = { @@ -91,8 +91,8 @@ def __init__( transport_callable: Callable[[], MCPTransport], *, startup_timeout: int = 30, - tool_filters: Optional[ToolFilters] = None, - prefix: Optional[str] = None, + tool_filters: ToolFilters | None = None, + prefix: str | None = None, ): """Initialize a new MCP Server connection. @@ -320,9 +320,9 @@ async def _set_close_event() -> None: def list_tools_sync( self, - pagination_token: Optional[str] = None, - prefix: Optional[str] = None, - tool_filters: Optional[ToolFilters] = None, + pagination_token: str | None = None, + prefix: str | None = None, + tool_filters: ToolFilters | None = None, ) -> PaginatedList[MCPAgentTool]: """Synchronously retrieves the list of available tools from the MCP server. @@ -709,7 +709,7 @@ def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: Optiona return True - def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolFilterPattern]) -> bool: + def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolMatcher]) -> bool: """Check if tool matches any of the given patterns.""" for pattern in patterns: if callable(pattern): diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 162c84c63..f0fa0d655 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -132,7 +132,7 @@ async def get_tools() -> Sequence[AgentTool]: return await tool.load_tools() provider_tools = run_async(get_tools) - tool.add_consumer(self._registry_id) # Now sync + tool.add_consumer(self._registry_id) for provider_tool in provider_tools: self.register_tool(provider_tool) @@ -661,23 +661,18 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: return tools def cleanup(self, **kwargs: Any) -> None: - """Synchronously clean up all tool providers in this registry. - - This method is safe to call from Agent finalizers during garbage collection - because it avoids run_async() which can deadlock when the GIL is held. - - The synchronous approach prevents: - 1. GC deadlocks - run_async() creates ThreadPoolExecutor during GC - 2. Interpreter shutdown hangs - ThreadPoolExecutor creation fails during shutdown - 3. Finalizer threading issues - weakref.finalize runs in restricted environment - """ + """Synchronously clean up all tool providers in this registry.""" + # Attempt cleanup of all providers even if one fails to minimize resource leakage during garbage collection + exceptions = [] for provider in self._tool_providers: try: - provider.remove_consumer(self._registry_id) # Now sync + provider.remove_consumer(self._registry_id) logger.debug("provider=<%s> | removed provider consumer", type(provider).__name__) except Exception as e: - logger.warning( - "provider=<%s>, error=<%s> | failed to remove provider consumer", - type(provider).__name__, - e, + exceptions.append(e) + logger.error( + "provider=<%s>, error=<%s> | failed to remove provider consumer", type(provider).__name__, e ) + + if exceptions: + raise exceptions[0] diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index 6fb2d6463..fdf4abb0a 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -36,6 +36,8 @@ def add_consumer(self, consumer_id): def remove_consumer(self, consumer_id): self.remove_consumer_called = True self.remove_consumer_id = consumer_id + if self._cleanup_error: + raise self._cleanup_error @pytest.fixture @@ -292,3 +294,35 @@ def remove_consumer(self, consumer_id): # Verify remove_consumer was called with correct ID assert provider.remove_consumer_called assert provider.remove_consumer_id == registry._registry_id + + def test_registry_cleanup_raises_exception_on_provider_error(self): + """Test that cleanup raises exception when provider removal fails.""" + provider1 = MockToolProvider(cleanup_error=RuntimeError("Provider cleanup failed")) + provider2 = MockToolProvider() + + registry = ToolRegistry() + registry._tool_providers = [provider1, provider2] + + # Cleanup should raise the exception from first provider but still attempt cleanup of all + with pytest.raises(RuntimeError, match="Provider cleanup failed"): + registry.cleanup() + + # Both providers should have had remove_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called + + def test_registry_cleanup_raises_first_exception_on_multiple_provider_errors(self): + """Test that cleanup raises first exception when multiple providers fail but attempts all.""" + provider1 = MockToolProvider(cleanup_error=RuntimeError("Provider 1 failed")) + provider2 = MockToolProvider(cleanup_error=ValueError("Provider 2 failed")) + + registry = ToolRegistry() + registry._tool_providers = [provider1, provider2] + + # Cleanup should raise first exception but still attempt cleanup of all + with pytest.raises(RuntimeError, match="Provider 1 failed"): + registry.cleanup() + + # Both providers should have had remove_consumer called + assert provider1.remove_consumer_called + assert provider2.remove_consumer_called From c88d5859402afae4b216257677923b72d7b747c5 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 21 Oct 2025 14:53:53 -0400 Subject: [PATCH 22/35] remove timeout --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 214e7c45f..b542c7481 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,7 +134,6 @@ dependencies = [ "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.3.0", "pytest-xdist>=3.0.0,<4.0.0", - "pytest-timeout>=2.0.0,<3.0.0", "moto>=5.1.0,<6.0.0", ] @@ -142,7 +141,7 @@ dependencies = [ python = ["3.13", "3.12", "3.11", "3.10"] [tool.hatch.envs.hatch-test.scripts] -run = "pytest{env:HATCH_TEST_ARGS:} --timeout=10 {args}" # Run with: hatch test +run = "pytest{env:HATCH_TEST_ARGS:} {args}" # Run with: hatch test run-cov = "pytest{env:HATCH_TEST_ARGS:} {args} --cov --cov-config=pyproject.toml --cov-report html --cov-report xml {args}" # Run with: hatch test -c cov-combine = [] cov-report = [] From beed7c1880813ff38be30c77d7fcaebd58089211 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 22 Oct 2025 09:30:34 -0400 Subject: [PATCH 23/35] fix integ test --- tests_integ/mcp/test_mcp_tool_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_integ/mcp/test_mcp_tool_provider.py b/tests_integ/mcp/test_mcp_tool_provider.py index acd32fbfe..7914bb326 100644 --- a/tests_integ/mcp/test_mcp_tool_provider.py +++ b/tests_integ/mcp/test_mcp_tool_provider.py @@ -96,7 +96,7 @@ def test_mcp_client_tool_provider_reuse(): agent1.cleanup() # Agent 2 should still be able to use the tool - result2 = agent2.tool.ref_echo(to_echo="Agent 2 Test") + result2 = agent2.tool.shared_echo(to_echo="Agent 2 Test") assert "Agent 2 Test" in str(result2) agent2.cleanup() From 9c5b91e233fc7a043a3b92246f63472065be8409 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 22 Oct 2025 10:47:57 -0400 Subject: [PATCH 24/35] rebase fix --- tests/strands/agent/test_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 46c48b004..251effc96 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -4,6 +4,7 @@ import os import textwrap import unittest.mock +import warnings from uuid import uuid4 import pytest From 4bfe6966d5522ef29e78da7ae4a95cc2ef2b2390 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 22 Oct 2025 11:26:07 -0400 Subject: [PATCH 25/35] fix circular dep --- src/strands/agent/agent.py | 7 +++++-- src/strands/tools/registry.py | 10 ++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3d405ed0e..177e6931e 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -16,6 +16,7 @@ import warnings import weakref from typing import ( + TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, @@ -34,7 +35,9 @@ from .. import _identifier from .._async import run_async from ..event_loop.event_loop import event_loop_cycle -from ..experimental.tools import ToolProvider + +if TYPE_CHECKING: + from ..experimental.tools import ToolProvider from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -231,7 +234,7 @@ def __init__( self, model: Union[Model, str, None] = None, messages: Optional[Messages] = None, - tools: Optional[list[Union[str, dict[str, str], ToolProvider, Any]]] = None, + tools: Optional[list[Union[str, dict[str, str], "ToolProvider", Any]]] = None, system_prompt: Optional[str] = None, callback_handler: Optional[ Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index f0fa0d655..e654716bd 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -13,18 +13,20 @@ from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence from typing_extensions import TypedDict, cast from strands.tools.decorator import DecoratedFunctionTool from .._async import run_async -from ..experimental.tools import ToolProvider from ..types.tools import AgentTool, ToolSpec from .loader import load_tool_from_string, load_tools_from_module from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec +if TYPE_CHECKING: + from ..experimental.tools import ToolProvider + logger = logging.getLogger(__name__) @@ -39,7 +41,7 @@ def __init__(self) -> None: self.registry: Dict[str, AgentTool] = {} self.dynamic_tools: Dict[str, AgentTool] = {} self.tool_config: Optional[Dict[str, Any]] = None - self._tool_providers: List[ToolProvider] = [] + self._tool_providers: List["ToolProvider"] = [] self._registry_id = str(uuid.uuid4()) def process_tools(self, tools: List[Any]) -> List[str]: @@ -125,7 +127,7 @@ def add_tool(tool: Any) -> None: add_tool(t) # Case 5: ToolProvider - elif isinstance(tool, ToolProvider): + elif hasattr(tool, 'load_tools') and hasattr(tool, 'add_consumer') and hasattr(tool, 'remove_consumer'): self._tool_providers.append(tool) async def get_tools() -> Sequence[AgentTool]: From f33f8ec59588cced12cca637aafbd39702726bca Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 22 Oct 2025 11:32:09 -0400 Subject: [PATCH 26/35] linting --- src/strands/tools/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index e654716bd..696877f8d 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -127,7 +127,7 @@ def add_tool(tool: Any) -> None: add_tool(t) # Case 5: ToolProvider - elif hasattr(tool, 'load_tools') and hasattr(tool, 'add_consumer') and hasattr(tool, 'remove_consumer'): + elif isinstance(tool, ToolProvider): self._tool_providers.append(tool) async def get_tools() -> Sequence[AgentTool]: From cb9d9541dbcebcd1ee022b8e8afe47161c7e60df Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 22 Oct 2025 11:34:45 -0400 Subject: [PATCH 27/35] fix linting --- src/strands/tools/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 696877f8d..e3be91370 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -127,7 +127,7 @@ def add_tool(tool: Any) -> None: add_tool(t) # Case 5: ToolProvider - elif isinstance(tool, ToolProvider): + elif isinstance(tool, "ToolProvider"): self._tool_providers.append(tool) async def get_tools() -> Sequence[AgentTool]: From 4fc00a82eed46cb7b5e3af59a3842f8aec59e589 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 22 Oct 2025 11:41:43 -0400 Subject: [PATCH 28/35] fix imports --- src/strands/experimental/agent_config.py | 7 ++++--- src/strands/tools/registry.py | 10 ++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index d08f89cf9..f65afb57d 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -18,8 +18,6 @@ import jsonschema from jsonschema import ValidationError -from ..agent import Agent - # JSON Schema for agent configuration AGENT_CONFIG_SCHEMA = { "$schema": "http://json-schema.org/draft-07/schema#", @@ -53,7 +51,7 @@ _VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA) -def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Agent: +def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Any: """Create an Agent from a configuration file or dictionary. This function supports tools that can be loaded declaratively (file paths, module names, @@ -134,5 +132,8 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A # Override with any additional kwargs provided agent_kwargs.update(kwargs) + # Import Agent at runtime to avoid circular imports + from ..agent import Agent + # Create and return Agent return Agent(**agent_kwargs) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index e3be91370..f0fa0d655 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -13,20 +13,18 @@ from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence +from typing import Any, Dict, Iterable, List, Optional, Sequence from typing_extensions import TypedDict, cast from strands.tools.decorator import DecoratedFunctionTool from .._async import run_async +from ..experimental.tools import ToolProvider from ..types.tools import AgentTool, ToolSpec from .loader import load_tool_from_string, load_tools_from_module from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec -if TYPE_CHECKING: - from ..experimental.tools import ToolProvider - logger = logging.getLogger(__name__) @@ -41,7 +39,7 @@ def __init__(self) -> None: self.registry: Dict[str, AgentTool] = {} self.dynamic_tools: Dict[str, AgentTool] = {} self.tool_config: Optional[Dict[str, Any]] = None - self._tool_providers: List["ToolProvider"] = [] + self._tool_providers: List[ToolProvider] = [] self._registry_id = str(uuid.uuid4()) def process_tools(self, tools: List[Any]) -> List[str]: @@ -127,7 +125,7 @@ def add_tool(tool: Any) -> None: add_tool(t) # Case 5: ToolProvider - elif isinstance(tool, "ToolProvider"): + elif isinstance(tool, ToolProvider): self._tool_providers.append(tool) async def get_tools() -> Sequence[AgentTool]: From c49570854287a9f282f6e30d7d43ebf48fa39007 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 23 Oct 2025 12:29:10 -0400 Subject: [PATCH 29/35] back to __del__ --- src/strands/agent/agent.py | 73 +-------- .../experimental/tools/tool_provider.py | 11 +- src/strands/tools/mcp/mcp_client.py | 11 +- src/strands/tools/registry.py | 4 +- tests/strands/agent/test_agent.py | 7 + tests/strands/tools/test_registry.py | 145 ++++++++++++++++++ 6 files changed, 176 insertions(+), 75 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 177e6931e..5dc133254 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -12,9 +12,7 @@ import json import logging import random -import uuid import warnings -import weakref from typing import ( TYPE_CHECKING, Any, @@ -88,26 +86,6 @@ class _DefaultCallbackHandlerSentinel: _DEFAULT_AGENT_NAME = "Strands Agents" _DEFAULT_AGENT_ID = "default" -"""Global private store for agent cleanup - maps UUID to ToolRegistry. - -Why use weakref.finalize with a global store? - -MCP clients spawn background threads that must be properly cleaned up to prevent -thread leaks. In __del__, the agent's references to these threads may already be -deleted while the threads still exist, making cleanup impossible. The threads -then continue running, consuming resources and causing interpreter shutdown hangs. - -weakref.finalize cannot access 'self' or instance attributes because: -1. Finalizers run AFTER the object is garbage collected -2. Accessing 'self' would create a circular reference preventing GC -3. Instance attributes may be in an undefined state during finalization -4. __del__ methods can access 'self' but are unreliable (may never run) - -The global store enables safe cleanup without circular references -and reliable execution during interpreter shutdown. -""" -_AGENT_CLEANUP_STORE: dict[str, "ToolRegistry"] = {} - class Agent: """Core Agent interface. @@ -360,9 +338,6 @@ def __init__( else: self.state = AgentState() - # Track cleanup state - self._cleanup_called = False - self.tool_caller = Agent.ToolCaller(self) self.hooks = HookRegistry() @@ -381,10 +356,6 @@ def __init__( self.hooks.add_hook(hook) self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) - self._agent_uuid = str(uuid.uuid4()) - _AGENT_CLEANUP_STORE[self._agent_uuid] = self.tool_registry - self._finalizer = weakref.finalize(self, self._cleanup_on_finalize, self._agent_uuid, self.agent_id) - @property def tool(self) -> ToolCaller: """Call tool as a function. @@ -571,42 +542,14 @@ def cleanup(self) -> None: Note: This method uses a "belt and braces" approach with automatic cleanup through finalizers as a fallback, but explicit cleanup is recommended. """ - if self._cleanup_called: - return - - self._cleanup_on_finalize(self._agent_uuid, self.agent_id) - self._cleanup_called = True - - @staticmethod - def _cleanup_on_finalize(agent_uuid: str, agent_id: str) -> None: - """Static cleanup method called by weakref.finalize. - - weakref.finalize is used over __del__ because: - 1. Runs AFTER garbage collection completes, not during (no GIL deadlocks) - 2. Cannot access 'self' so can't call methods that might block (no run_async deadlocks) - 3. Executes in a controlled environment where Python isn't in restricted GC state - 4. More reliable execution timing - __del__ can be delayed or skipped entirely - 5. No circular reference issues that can prevent __del__ from being called - 6. Uses global store to avoid copying complex objects at registration time - - SYNCHRONOUS CONSUMER MANAGEMENT PREVENTS: - - GC Deadlocks: run_async() creates ThreadPoolExecutor while GIL is held during GC - - Interpreter Shutdown Hangs: ThreadPoolExecutor creation fails during shutdown - - Finalizer Threading Issues: weakref.finalize runs in restricted threading environment - - Resource Leaks: Ensures MCP background threads are properly stopped - - The synchronous approach uses MCPClient.stop() which safely: - - Signals background thread via asyncio.Event - - Waits for thread completion with thread.join() - - Cleans up all resources without async/await - """ - logger.debug("agent_id=<%s> | starting finalize cleanup", agent_id) - tool_registry = _AGENT_CLEANUP_STORE.pop(agent_uuid, None) - if not tool_registry: - return - - # Use synchronous cleanup to avoid run_async deadlocks during GC - tool_registry.cleanup() + self.tool_registry.cleanup() + + def __del__(self) -> None: + """Clean up resources when agent is garbage collected.""" + # __del__ is called even when an exception is thrown in the constructor, + # so there is no guarantee tool_registry was set.. + if hasattr(self, "tool_registry"): + self.tool_registry.cleanup() async def stream_async( self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 03f4f7aa3..2c79ceafc 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -27,23 +27,26 @@ async def load_tools(self, **kwargs: Any) -> Sequence["AgentTool"]: pass @abstractmethod - def add_consumer(self, id: Any, **kwargs: Any) -> None: + def add_consumer(self, consumer_id: Any, **kwargs: Any) -> None: """Add a consumer to this tool provider. Args: - id: Unique identifier for the consumer. + consumer_id: Unique identifier for the consumer. **kwargs: Additional arguments for future compatibility. """ pass @abstractmethod - def remove_consumer(self, id: Any, **kwargs: Any) -> None: + def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: """Remove a consumer from this tool provider. + This method must be idempotent - calling it multiple times with the same ID + should have no additional effect after the first call. + Provider may clean up resources when no consumers remain. Args: - id: Unique identifier for the consumer. + consumer_id: Unique identifier for the consumer. **kwargs: Additional arguments for future compatibility. """ pass diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 67ded10bb..a6402b09a 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -235,21 +235,24 @@ async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: return self._loaded_tools - def add_consumer(self, id: Any, **kwargs: Any) -> None: + def add_consumer(self, consumer_id: Any, **kwargs: Any) -> None: """Add a consumer to this tool provider. Synchronous to prevent GC deadlocks when called from Agent finalizers. """ - self._consumers.add(id) + self._consumers.add(consumer_id) logger.debug("added provider consumer, count=%d", len(self._consumers)) - def remove_consumer(self, id: Any, **kwargs: Any) -> None: + def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: """Remove a consumer from this tool provider. + This method is idempotent - calling it multiple times with the same ID + has no additional effect after the first call. + Synchronous to prevent GC deadlocks when called from Agent finalizers. Uses existing synchronous stop() method for safe cleanup. """ - self._consumers.discard(id) + self._consumers.discard(consumer_id) logger.debug("removed provider consumer, count=%d", len(self._consumers)) if not self._consumers and self._tool_provider_started: diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index f0fa0d655..964050597 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -127,12 +127,12 @@ def add_tool(tool: Any) -> None: # Case 5: ToolProvider elif isinstance(tool, ToolProvider): self._tool_providers.append(tool) + tool.add_consumer(self._registry_id) async def get_tools() -> Sequence[AgentTool]: return await tool.load_tools() provider_tools = run_async(get_tools) - tool.add_consumer(self._registry_id) for provider_tool in provider_tools: self.register_tool(provider_tool) @@ -662,7 +662,7 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: def cleanup(self, **kwargs: Any) -> None: """Synchronously clean up all tool providers in this registry.""" - # Attempt cleanup of all providers even if one fails to minimize resource leakage during garbage collection + # Attempt cleanup of all providers even if one fails to minimize resource leakage exceptions = [] for provider in self._tool_providers: try: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 251effc96..43fca1354 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2057,3 +2057,10 @@ def test_agent_tool_caller_interrupt(user): exp_message = r"cannot directly call tool during interrupt" with pytest.raises(RuntimeError, match=exp_message): agent.tool.test_tool() + + +def test_agent_del_before_tool_registry_set(): + """Test that Agent.__del__ doesn't fail if called before tool_registry is set.""" + agent = Agent() + del agent.tool_registry + agent.__del__() # Should not raise diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index ee0098adc..1bd4ef13f 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -260,3 +260,148 @@ def test_register_strands_tools_module_non_callable_function(): " Tool tool_with_spec_but_non_callable_function function is not callable", ): tool_registry.process_tools(["tests.fixtures.tool_with_spec_but_non_callable_function"]) + + +def test_tool_registry_cleanup_with_mcp_client(): + """Test that ToolRegistry cleanup properly handles MCP clients without orphaning threads.""" + from unittest.mock import AsyncMock, MagicMock + + from strands.tools.mcp import MCPClient + + # Create a mock MCP client that simulates a real tool provider + mock_transport = MagicMock() + mock_client = MCPClient(mock_transport) + + # Mock the client to avoid actual network operations + mock_client.load_tools = AsyncMock(return_value=[]) + + registry = ToolRegistry() + + # Use process_tools to properly register the client + registry.process_tools([mock_client]) + + # Verify the client was registered as a consumer + assert registry._registry_id in mock_client._consumers + + # Test cleanup calls remove_consumer + registry.cleanup() + + # Verify cleanup was attempted + assert registry._registry_id not in mock_client._consumers + + +def test_tool_registry_cleanup_exception_handling(): + """Test that ToolRegistry cleanup attempts all providers even if some fail.""" + from unittest.mock import MagicMock + + # Create mock providers - one that fails, one that succeeds + failing_provider = MagicMock() + failing_provider.remove_consumer.side_effect = Exception("Cleanup failed") + + working_provider = MagicMock() + + registry = ToolRegistry() + registry._tool_providers = [failing_provider, working_provider] + + # Cleanup should attempt both providers and raise the first exception + with pytest.raises(Exception, match="Cleanup failed"): + registry.cleanup() + + # Verify both providers were attempted + failing_provider.remove_consumer.assert_called_once() + working_provider.remove_consumer.assert_called_once() + + +def test_tool_registry_cleanup_idempotent(): + """Test that ToolRegistry cleanup is idempotent.""" + from unittest.mock import AsyncMock, MagicMock + + from strands.experimental.tools import ToolProvider + + provider = MagicMock(spec=ToolProvider) + provider.load_tools = AsyncMock(return_value=[]) + + registry = ToolRegistry() + + # Use process_tools to properly register the provider + registry.process_tools([provider]) + + # First cleanup should call remove_consumer + registry.cleanup() + provider.remove_consumer.assert_called_once_with(registry._registry_id) + + # Reset mock call count + provider.remove_consumer.reset_mock() + + # Second cleanup should call remove_consumer again (not idempotent yet) + # This test documents current behavior - registry cleanup is not idempotent + registry.cleanup() + provider.remove_consumer.assert_called_once_with(registry._registry_id) + + +def test_tool_registry_process_tools_exception_after_add_consumer(): + """Test that tool provider is still tracked for cleanup even if load_tools fails.""" + from unittest.mock import AsyncMock, MagicMock + + from strands.experimental.tools import ToolProvider + + # Create a mock tool provider that fails during load_tools + mock_provider = MagicMock(spec=ToolProvider) + mock_provider.add_consumer = MagicMock() + mock_provider.remove_consumer = MagicMock() + + async def failing_load_tools(): + raise Exception("Failed to load tools") + + mock_provider.load_tools = AsyncMock(side_effect=failing_load_tools) + + registry = ToolRegistry() + + # Processing should fail but provider should still be tracked + with pytest.raises(ValueError, match="Failed to load tool"): + registry.process_tools([mock_provider]) + + # Verify provider was added to registry for cleanup tracking + assert mock_provider in registry._tool_providers + + # Verify add_consumer was called before the failure + mock_provider.add_consumer.assert_called_once_with(registry._registry_id) + + # Cleanup should still work + registry.cleanup() + mock_provider.remove_consumer.assert_called_once_with(registry._registry_id) + + +def test_tool_registry_add_consumer_before_load_tools(): + """Test that add_consumer is called before load_tools to ensure cleanup tracking.""" + from unittest.mock import AsyncMock, MagicMock + + from strands.experimental.tools import ToolProvider + + # Create a mock tool provider that tracks call order + mock_provider = MagicMock(spec=ToolProvider) + call_order = [] + + def track_add_consumer(*args, **kwargs): + call_order.append("add_consumer") + + async def track_load_tools(*args, **kwargs): + call_order.append("load_tools") + return [] + + mock_provider.add_consumer.side_effect = track_add_consumer + mock_provider.load_tools = AsyncMock(side_effect=track_load_tools) + + registry = ToolRegistry() + + # Process the tool provider + registry.process_tools([mock_provider]) + + # Verify add_consumer was called before load_tools + assert call_order == ["add_consumer", "load_tools"] + + # Verify the provider was added to the registry for cleanup + assert mock_provider in registry._tool_providers + + # Verify add_consumer was called with the registry ID + mock_provider.add_consumer.assert_called_once_with(registry._registry_id) From 510309224e97a3d4c753da21ffc48a3af71e2976 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 23 Oct 2025 16:11:34 -0400 Subject: [PATCH 30/35] instrumentation fix --- src/strands/tools/mcp/mcp_client.py | 3 ++- .../strands/tools/mcp/test_mcp_instrumentation.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index a6402b09a..693dde327 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -33,6 +33,7 @@ from ...types.media import ImageFormat from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool +from .mcp_instrumentation import mcp_instrumentation from .mcp_types import MCPToolResult, MCPTransport logger = logging.getLogger(__name__) @@ -107,7 +108,7 @@ def __init__( self._tool_filters = tool_filters self._prefix = prefix - # mcp_instrumentation() + mcp_instrumentation() self._session_id = uuid.uuid4() self._log_debug_with_thread("initializing MCPClient connection") # Main thread blocks until future completesock diff --git a/tests/strands/tools/mcp/test_mcp_instrumentation.py b/tests/strands/tools/mcp/test_mcp_instrumentation.py index 2c730624e..85d533403 100644 --- a/tests/strands/tools/mcp/test_mcp_instrumentation.py +++ b/tests/strands/tools/mcp/test_mcp_instrumentation.py @@ -340,6 +340,21 @@ def __getattr__(self, name): class TestMCPInstrumentation: + def test_mcp_instrumentation_called_on_client_init(self): + """Test that mcp_instrumentation is called when MCPClient is initialized.""" + with patch("strands.tools.mcp.mcp_client.mcp_instrumentation") as mock_instrumentation: + # Mock transport + def mock_transport(): + read_stream = AsyncMock() + write_stream = AsyncMock() + return read_stream, write_stream + + # Create MCPClient instance - should call mcp_instrumentation + MCPClient(mock_transport) + + # Verify mcp_instrumentation was called + mock_instrumentation.assert_called_once() + def test_mcp_instrumentation_idempotent_with_multiple_clients(self): """Test that mcp_instrumentation is only called once even with multiple MCPClient instances.""" From 3f30664c8ff1b26688cb124bd36185e8cc57874f Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 23 Oct 2025 16:19:15 -0400 Subject: [PATCH 31/35] instrumentation fix --- src/strands/tools/mcp/mcp_client.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 693dde327..06deee621 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -12,6 +12,7 @@ import logging import threading import uuid +import warnings from asyncio import AbstractEventLoop from concurrent import futures from datetime import timedelta @@ -85,6 +86,10 @@ class MCPClient(ToolProvider): The connection runs in a background thread to avoid blocking the main application thread while maintaining communication with the MCP service. When structured content is available from MCP tools, it will be returned as the last item in the content array of the ToolResult. + + Warning: + This class implements the experimental ToolProvider interface and its methods + are subject to change. """ def __init__( @@ -108,6 +113,14 @@ def __init__( self._tool_filters = tool_filters self._prefix = prefix + # Warn about experimental ToolProvider interface + warnings.warn( + "MCPClient implements the experimental ToolProvider interface. " + "This interface and its methods are subject to change in future versions.", + FutureWarning, + stacklevel=2 + ) + mcp_instrumentation() self._session_id = uuid.uuid4() self._log_debug_with_thread("initializing MCPClient connection") From 983d6dd83fac835b0cd5346bfbb8a5542e5da9f3 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 23 Oct 2025 16:33:27 -0400 Subject: [PATCH 32/35] remove warning --- src/strands/tools/mcp/mcp_client.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 06deee621..847a9a056 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -113,14 +113,6 @@ def __init__( self._tool_filters = tool_filters self._prefix = prefix - # Warn about experimental ToolProvider interface - warnings.warn( - "MCPClient implements the experimental ToolProvider interface. " - "This interface and its methods are subject to change in future versions.", - FutureWarning, - stacklevel=2 - ) - mcp_instrumentation() self._session_id = uuid.uuid4() self._log_debug_with_thread("initializing MCPClient connection") From fbcc356cc3ac4406b22e81d488b6f00548fe60b0 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 23 Oct 2025 16:37:56 -0400 Subject: [PATCH 33/35] remove warning --- src/strands/tools/mcp/mcp_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 847a9a056..baeed9d13 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -12,7 +12,6 @@ import logging import threading import uuid -import warnings from asyncio import AbstractEventLoop from concurrent import futures from datetime import timedelta From a1279c53e205289890131214ae8c0e189ff48f72 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 24 Oct 2025 11:31:59 -0400 Subject: [PATCH 34/35] Update test_registry.py --- tests/strands/tools/test_registry.py | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 1bd4ef13f..6899b4535 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -2,14 +2,15 @@ Tests for the SDK tool registry module. """ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest import strands +from strands.experimental.tools import ToolProvider from strands.tools import PythonAgentTool from strands.tools.decorator import DecoratedFunctionTool, tool -from strands.tools.registry import ToolRegistry +from strands.tools.mcp import MCPClient def test_load_tool_from_filepath_failure(): @@ -264,10 +265,6 @@ def test_register_strands_tools_module_non_callable_function(): def test_tool_registry_cleanup_with_mcp_client(): """Test that ToolRegistry cleanup properly handles MCP clients without orphaning threads.""" - from unittest.mock import AsyncMock, MagicMock - - from strands.tools.mcp import MCPClient - # Create a mock MCP client that simulates a real tool provider mock_transport = MagicMock() mock_client = MCPClient(mock_transport) @@ -292,8 +289,6 @@ def test_tool_registry_cleanup_with_mcp_client(): def test_tool_registry_cleanup_exception_handling(): """Test that ToolRegistry cleanup attempts all providers even if some fail.""" - from unittest.mock import MagicMock - # Create mock providers - one that fails, one that succeeds failing_provider = MagicMock() failing_provider.remove_consumer.side_effect = Exception("Cleanup failed") @@ -314,10 +309,6 @@ def test_tool_registry_cleanup_exception_handling(): def test_tool_registry_cleanup_idempotent(): """Test that ToolRegistry cleanup is idempotent.""" - from unittest.mock import AsyncMock, MagicMock - - from strands.experimental.tools import ToolProvider - provider = MagicMock(spec=ToolProvider) provider.load_tools = AsyncMock(return_value=[]) @@ -341,10 +332,6 @@ def test_tool_registry_cleanup_idempotent(): def test_tool_registry_process_tools_exception_after_add_consumer(): """Test that tool provider is still tracked for cleanup even if load_tools fails.""" - from unittest.mock import AsyncMock, MagicMock - - from strands.experimental.tools import ToolProvider - # Create a mock tool provider that fails during load_tools mock_provider = MagicMock(spec=ToolProvider) mock_provider.add_consumer = MagicMock() @@ -374,10 +361,6 @@ async def failing_load_tools(): def test_tool_registry_add_consumer_before_load_tools(): """Test that add_consumer is called before load_tools to ensure cleanup tracking.""" - from unittest.mock import AsyncMock, MagicMock - - from strands.experimental.tools import ToolProvider - # Create a mock tool provider that tracks call order mock_provider = MagicMock(spec=ToolProvider) call_order = [] From cc217d677e4e996c60a9e0162ad6706018399c9a Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 24 Oct 2025 11:56:53 -0400 Subject: [PATCH 35/35] rebase --- src/strands/agent/agent.py | 2 +- src/strands/tools/mcp/mcp_client.py | 2 +- src/strands/types/exceptions.py | 18 +++++++++--------- tests/strands/agent/test_agent.py | 2 +- tests/strands/tools/test_registry.py | 1 + 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3c735f23b..92c272c41 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -498,7 +498,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> category=DeprecationWarning, stacklevel=2, ) - + return run_async(lambda: self.structured_output_async(output_model, prompt)) async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index baeed9d13..61f3d9185 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -85,7 +85,7 @@ class MCPClient(ToolProvider): The connection runs in a background thread to avoid blocking the main application thread while maintaining communication with the MCP service. When structured content is available from MCP tools, it will be returned as the last item in the content array of the ToolResult. - + Warning: This class implements the experimental ToolProvider interface and its methods are subject to change. diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 349c6b0de..b9c5bc769 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -82,15 +82,15 @@ class ToolProviderException(Exception): pass - + class StructuredOutputException(Exception): - """Exception raised when structured output validation fails after maximum retry attempts.""" + """Exception raised when structured output validation fails after maximum retry attempts.""" - def __init__(self, message: str): - """Initialize the exception with details about the failure. + def __init__(self, message: str): + """Initialize the exception with details about the failure. - Args: - message: The error message describing the structured output failure - """ - self.message = message - super().__init__(message) + Args: + message: The error message describing the structured output failure + """ + self.message = message + super().__init__(message) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d111acb48..403f858b5 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2069,7 +2069,7 @@ def test_agent_del_before_tool_registry_set(): del agent.tool_registry agent.__del__() # Should not raise - + def test_agent__call__invalid_tool_name(): @strands.tool def shell(command: str): diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 6899b4535..c700016f6 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -11,6 +11,7 @@ from strands.tools import PythonAgentTool from strands.tools.decorator import DecoratedFunctionTool, tool from strands.tools.mcp import MCPClient +from strands.tools.registry import ToolRegistry def test_load_tool_from_filepath_failure():