From da58123735728e023764a171915e6666449514e5 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 9 Sep 2025 17:25:23 -0400 Subject: [PATCH 1/3] fix(mcp): auto cleanup on exceptions occurring in __enter__ --- src/strands/models/bedrock.py | 1 + src/strands/tools/mcp/mcp_client.py | 43 +++++++++++++++++----- tests/strands/tools/mcp/test_mcp_client.py | 23 +++++++++++- tests_integ/test_mcp_client.py | 29 +++++++++++++++ 4 files changed, 84 insertions(+), 12 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 8909072f6..d0dfefdfa 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -47,6 +47,7 @@ DEFAULT_READ_TIMEOUT = 120 + class BedrockModel(Model): """AWS Bedrock model provider implementation. diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 5d9dd0b0f..fd7af92b1 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -83,11 +83,15 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti self._transport_callable = transport_callable self._background_thread: threading.Thread | None = None - self._background_thread_session: ClientSession - self._background_thread_event_loop: AbstractEventLoop + self._background_thread_session: ClientSession | None = None + self._background_thread_event_loop: AbstractEventLoop | None = None def __enter__(self) -> "MCPClient": - """Context manager entry point which initializes the MCP server connection.""" + """Context manager entry point which initializes the MCP server connection. + + TODO: Refactor to lazy initialization pattern following idiomatic Python. + Heavy work in __enter__ is non-idiomatic - should move connection logic to first method call instead. + """ return self.start() def __exit__(self, exc_type: BaseException, exc_val: BaseException, exc_tb: TracebackType) -> None: @@ -118,9 +122,15 @@ def start(self) -> "MCPClient": self._init_future.result(timeout=self._startup_timeout) self._log_debug_with_thread("the client initialization was successful") except futures.TimeoutError as e: - raise MCPClientInitializationError("background thread did not start in 30 seconds") from e + # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit + self.stop(None, None, None) + raise MCPClientInitializationError( + f"background thread did not start in {self._startup_timeout} seconds" + ) from e except Exception as e: logger.exception("client failed to initialize") + # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit + self.stop(None, None, None) raise MCPClientInitializationError("the client initialization failed") from e return self @@ -129,6 +139,9 @@ def stop( ) -> None: """Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources. + This method is defensive and can handle partial initialization states that may occur + if start() fails partway through initialization. + Args: exc_type: Exception type if an exception was raised in the context exc_val: Exception value if an exception was raised in the context @@ -136,14 +149,19 @@ def stop( """ self._log_debug_with_thread("exiting MCPClient context") - async def _set_close_event() -> None: - self._close_event.set() - - self._invoke_on_background_thread(_set_close_event()).result() - self._log_debug_with_thread("waiting for background thread to join") + # Only try to signal close event if we have a background thread if self._background_thread is not None: + # Signal close event if event loop exists + if self._background_thread_event_loop is not None: + + async def _set_close_event() -> None: + self._close_event.set() + + asyncio.run_coroutine_threadsafe(_set_close_event(), self._background_thread_event_loop) + + self._log_debug_with_thread("waiting for background thread to join") self._background_thread.join() - self._log_debug_with_thread("background thread joined, MCPClient context exited") + self._log_debug_with_thread("background thread joined, MCPClient context exited") # Reset fields to allow instance reuse self._init_future = futures.Future() @@ -165,6 +183,7 @@ def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedLi raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _list_tools_async() -> ListToolsResult: + assert self._background_thread_session is not None return await self._background_thread_session.list_tools(cursor=pagination_token) list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() @@ -191,6 +210,7 @@ def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromp raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _list_prompts_async() -> ListPromptsResult: + assert self._background_thread_session is not None return await self._background_thread_session.list_prompts(cursor=pagination_token) list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result() @@ -215,6 +235,7 @@ def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResu raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _get_prompt_async() -> GetPromptResult: + assert self._background_thread_session is not None return await self._background_thread_session.get_prompt(prompt_id, arguments=args) get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result() @@ -250,6 +271,7 @@ def call_tool_sync( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _call_tool_async() -> MCPCallToolResult: + assert self._background_thread_session is not None return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds) try: @@ -285,6 +307,7 @@ async def call_tool_async( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _call_tool_async() -> MCPCallToolResult: + assert self._background_thread_session is not None return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds) try: diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index bd88382cd..53caa3da4 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -337,8 +337,12 @@ def test_enter_with_initialization_exception(mock_transport): client = MCPClient(mock_transport["transport_callable"]) - with pytest.raises(MCPClientInitializationError, match="the client initialization failed"): - client.start() + with patch.object(client, "stop") as mock_stop: + with pytest.raises(MCPClientInitializationError, match="the client initialization failed"): + client.start() + + # Verify stop() was called for cleanup + mock_stop.assert_called_once_with(None, None, None) def test_mcp_tool_result_type(): @@ -466,3 +470,18 @@ def test_get_prompt_sync_session_not_active(): with pytest.raises(MCPClientInitializationError, match="client session is not running"): client.get_prompt_sync("test_prompt_id", {}) + + +def test_timeout_initialization_cleanup(): + """Test that timeout during initialization properly cleans up.""" + + def slow_transport(): + time.sleep(5) + return MagicMock() + + client = MCPClient(slow_transport, startup_timeout=1) + + with patch.object(client, "stop") as mock_stop: + with pytest.raises(MCPClientInitializationError, match="background thread did not start in 1 seconds"): + client.start() + mock_stop.assert_called_once_with(None, None, None) diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py index 3de249435..4e358f4f2 100644 --- a/tests_integ/test_mcp_client.py +++ b/tests_integ/test_mcp_client.py @@ -15,6 +15,7 @@ from strands.tools.mcp.mcp_client import MCPClient from strands.tools.mcp.mcp_types import MCPTransport from strands.types.content import Message +from strands.types.exceptions import MCPClientInitializationError from strands.types.tools import ToolUse @@ -268,3 +269,31 @@ def transport_callback() -> MCPTransport: def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] + + +def test_mcp_client_timeout_integration(): + """Integration test for timeout scenario that caused hanging.""" + import threading + + from mcp import StdioServerParameters, stdio_client + + def slow_transport(): + time.sleep(4) # Longer than timeout + return stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + + client = MCPClient(slow_transport, startup_timeout=2) + initial_threads = threading.active_count() + + # First attempt should timeout + with pytest.raises(MCPClientInitializationError, match="background thread did not start in 2 seconds"): + with client: + pass + + time.sleep(1) # Allow cleanup + assert threading.active_count() == initial_threads # No thread leak + + # Should be able to recover by increasing timeout + client._startup_timeout = 60 + with client: + tools = client.list_tools_sync() + assert len(tools) >= 0 # Should work now From 12e3037608fb70b67e10b45aec4e260082198b67 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 9 Sep 2025 18:32:35 -0400 Subject: [PATCH 2/3] tests: add more tests --- src/strands/tools/mcp/mcp_client.py | 23 +++++++------- tests/strands/tools/mcp/test_mcp_client.py | 36 ++++++++++++++++++++++ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index fd7af92b1..ea14bf81f 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -16,7 +16,7 @@ from concurrent import futures from datetime import timedelta from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union +from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast from mcp import ClientSession, ListToolsResult from mcp.types import CallToolResult as MCPCallToolResult @@ -157,7 +157,7 @@ def stop( async def _set_close_event() -> None: self._close_event.set() - asyncio.run_coroutine_threadsafe(_set_close_event(), self._background_thread_event_loop) + self._invoke_on_background_thread(_set_close_event()).result() self._log_debug_with_thread("waiting for background thread to join") self._background_thread.join() @@ -183,8 +183,7 @@ def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedLi raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _list_tools_async() -> ListToolsResult: - assert self._background_thread_session is not None - return await self._background_thread_session.list_tools(cursor=pagination_token) + return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token) list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools)) @@ -210,8 +209,7 @@ def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromp raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _list_prompts_async() -> ListPromptsResult: - assert self._background_thread_session is not None - return await self._background_thread_session.list_prompts(cursor=pagination_token) + return await cast(ClientSession, self._background_thread_session).list_prompts(cursor=pagination_token) list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result() self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts)) @@ -235,8 +233,7 @@ def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResu raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _get_prompt_async() -> GetPromptResult: - assert self._background_thread_session is not None - return await self._background_thread_session.get_prompt(prompt_id, arguments=args) + return await cast(ClientSession, self._background_thread_session).get_prompt(prompt_id, arguments=args) get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result() self._log_debug_with_thread("received prompt from MCP server") @@ -271,8 +268,9 @@ def call_tool_sync( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _call_tool_async() -> MCPCallToolResult: - assert self._background_thread_session is not None - return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds) + return await cast(ClientSession, self._background_thread_session).call_tool( + name, arguments, read_timeout_seconds + ) try: call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result() @@ -307,8 +305,9 @@ async def call_tool_async( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _call_tool_async() -> MCPCallToolResult: - assert self._background_thread_session is not None - return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds) + return await cast(ClientSession, self._background_thread_session).call_tool( + name, arguments, read_timeout_seconds + ) try: future = self._invoke_on_background_thread(_call_tool_async()) diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 53caa3da4..8514a67d4 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -485,3 +485,39 @@ def slow_transport(): with pytest.raises(MCPClientInitializationError, match="background thread did not start in 1 seconds"): client.start() mock_stop.assert_called_once_with(None, None, None) + + +def test_stop_with_no_background_thread(): + """Test that stop() handles the case when no background thread exists.""" + client = MCPClient(MagicMock()) + + # Ensure no background thread exists + assert client._background_thread is None + + # Mock join to verify it's not called + with patch("threading.Thread.join") as mock_join: + client.stop(None, None, None) + mock_join.assert_not_called() + + # Verify cleanup occurred + assert client._background_thread is None + + +def test_stop_with_background_thread_but_no_event_loop(): + """Test that stop() handles the case when background thread exists but event loop is None.""" + client = MCPClient(MagicMock()) + + # Mock a background thread without event loop + mock_thread = MagicMock() + mock_thread.join = MagicMock() + client._background_thread = mock_thread + client._background_thread_event_loop = None + + # Should not raise any exceptions and should join the thread + client.stop(None, None, None) + + # Verify thread was joined + mock_thread.join.assert_called_once() + + # Verify cleanup occurred + assert client._background_thread is None From 7c773c8b425ebf98f344f2258af779ee377d812b Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 11 Sep 2025 16:45:58 -0400 Subject: [PATCH 3/3] respond to comments --- src/strands/tools/mcp/mcp_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index ea14bf81f..402005604 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -122,6 +122,7 @@ def start(self) -> "MCPClient": self._init_future.result(timeout=self._startup_timeout) self._log_debug_with_thread("the client initialization was successful") except futures.TimeoutError as e: + logger.exception("client initialization timed out") # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit self.stop(None, None, None) raise MCPClientInitializationError( @@ -161,7 +162,7 @@ async def _set_close_event() -> None: self._log_debug_with_thread("waiting for background thread to join") self._background_thread.join() - self._log_debug_with_thread("background thread joined, MCPClient context exited") + self._log_debug_with_thread("background thread is closed, MCPClient context exited") # Reset fields to allow instance reuse self._init_future = futures.Future()