Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

DEFAULT_READ_TIMEOUT = 120


class BedrockModel(Model):
"""AWS Bedrock model provider implementation.

Expand Down
55 changes: 39 additions & 16 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from concurrent import futures
from datetime import timedelta
from types import TracebackType
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast

from mcp import ClientSession, ListToolsResult
from mcp.types import CallToolResult as MCPCallToolResult
Expand Down Expand Up @@ -83,11 +83,15 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti
self._transport_callable = transport_callable

self._background_thread: threading.Thread | None = None
self._background_thread_session: ClientSession
self._background_thread_event_loop: AbstractEventLoop
self._background_thread_session: ClientSession | None = None
self._background_thread_event_loop: AbstractEventLoop | None = None

def __enter__(self) -> "MCPClient":
"""Context manager entry point which initializes the MCP server connection."""
"""Context manager entry point which initializes the MCP server connection.

TODO: Refactor to lazy initialization pattern following idiomatic Python.
Heavy work in __enter__ is non-idiomatic - should move connection logic to first method call instead.
"""
return self.start()

def __exit__(self, exc_type: BaseException, exc_val: BaseException, exc_tb: TracebackType) -> None:
Expand Down Expand Up @@ -118,9 +122,16 @@ def start(self) -> "MCPClient":
self._init_future.result(timeout=self._startup_timeout)
self._log_debug_with_thread("the client initialization was successful")
except futures.TimeoutError as e:
raise MCPClientInitializationError("background thread did not start in 30 seconds") from e
logger.exception("client initialization timed out")
# Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit
self.stop(None, None, None)
raise MCPClientInitializationError(
f"background thread did not start in {self._startup_timeout} seconds"
) from e
except Exception as e:
logger.exception("client failed to initialize")
# Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit
self.stop(None, None, None)
raise MCPClientInitializationError("the client initialization failed") from e
return self

Expand All @@ -129,21 +140,29 @@ def stop(
) -> None:
"""Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources.

This method is defensive and can handle partial initialization states that may occur
if start() fails partway through initialization.

Args:
exc_type: Exception type if an exception was raised in the context
exc_val: Exception value if an exception was raised in the context
exc_tb: Exception traceback if an exception was raised in the context
"""
self._log_debug_with_thread("exiting MCPClient context")

async def _set_close_event() -> None:
self._close_event.set()

self._invoke_on_background_thread(_set_close_event()).result()
self._log_debug_with_thread("waiting for background thread to join")
# Only try to signal close event if we have a background thread
if self._background_thread is not None:
# Signal close event if event loop exists
if self._background_thread_event_loop is not None:

async def _set_close_event() -> None:
self._close_event.set()

self._invoke_on_background_thread(_set_close_event()).result()

self._log_debug_with_thread("waiting for background thread to join")
self._background_thread.join()
self._log_debug_with_thread("background thread joined, MCPClient context exited")
self._log_debug_with_thread("background thread is closed, MCPClient context exited")

# Reset fields to allow instance reuse
self._init_future = futures.Future()
Expand All @@ -165,7 +184,7 @@ def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedLi
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

async def _list_tools_async() -> ListToolsResult:
return await self._background_thread_session.list_tools(cursor=pagination_token)
return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token)

list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result()
self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools))
Expand All @@ -191,7 +210,7 @@ def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromp
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

async def _list_prompts_async() -> ListPromptsResult:
return await self._background_thread_session.list_prompts(cursor=pagination_token)
return await cast(ClientSession, self._background_thread_session).list_prompts(cursor=pagination_token)

list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result()
self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts))
Expand All @@ -215,7 +234,7 @@ def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResu
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

async def _get_prompt_async() -> GetPromptResult:
return await self._background_thread_session.get_prompt(prompt_id, arguments=args)
return await cast(ClientSession, self._background_thread_session).get_prompt(prompt_id, arguments=args)

get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result()
self._log_debug_with_thread("received prompt from MCP server")
Expand Down Expand Up @@ -250,7 +269,9 @@ def call_tool_sync(
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

async def _call_tool_async() -> MCPCallToolResult:
return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)
return await cast(ClientSession, self._background_thread_session).call_tool(
name, arguments, read_timeout_seconds
)

try:
call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result()
Expand Down Expand Up @@ -285,7 +306,9 @@ async def call_tool_async(
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

async def _call_tool_async() -> MCPCallToolResult:
return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)
return await cast(ClientSession, self._background_thread_session).call_tool(
name, arguments, read_timeout_seconds
)

try:
future = self._invoke_on_background_thread(_call_tool_async())
Expand Down
59 changes: 57 additions & 2 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,12 @@ def test_enter_with_initialization_exception(mock_transport):

client = MCPClient(mock_transport["transport_callable"])

with pytest.raises(MCPClientInitializationError, match="the client initialization failed"):
client.start()
with patch.object(client, "stop") as mock_stop:
with pytest.raises(MCPClientInitializationError, match="the client initialization failed"):
client.start()

# Verify stop() was called for cleanup
mock_stop.assert_called_once_with(None, None, None)


def test_mcp_tool_result_type():
Expand Down Expand Up @@ -466,3 +470,54 @@ def test_get_prompt_sync_session_not_active():

with pytest.raises(MCPClientInitializationError, match="client session is not running"):
client.get_prompt_sync("test_prompt_id", {})


def test_timeout_initialization_cleanup():
    """Verify that a startup timeout makes start() call stop() for cleanup.

    The transport callable blocks for longer than ``startup_timeout``, and
    ``stop`` is patched on the instance so the cleanup call can be observed.
    """

    def stalled_transport():
        # Outlast the 1-second startup timeout configured below.
        time.sleep(5)
        return MagicMock()

    client = MCPClient(stalled_transport, startup_timeout=1)
    expected_message = "background thread did not start in 1 seconds"

    with patch.object(client, "stop") as stop_spy:
        with pytest.raises(MCPClientInitializationError, match=expected_message):
            client.start()

        # Cleanup is requested with the "no active exception" triple.
        stop_spy.assert_called_once_with(None, None, None)


def test_stop_with_no_background_thread():
    """Check that stop() does not join any thread when none was ever started."""
    client = MCPClient(MagicMock())

    # Precondition: a freshly constructed client holds no background thread.
    assert client._background_thread is None

    # Patch Thread.join so that any attempt to join would be recorded.
    with patch("threading.Thread.join") as join_spy:
        client.stop(None, None, None)

    join_spy.assert_not_called()

    # The thread reference is still unset after stop().
    assert client._background_thread is None


def test_stop_with_background_thread_but_no_event_loop():
    """Check that stop() still joins the thread when the event loop is None."""
    client = MCPClient(MagicMock())

    # Simulate partial initialization: a thread object exists, the loop does not.
    thread_stub = MagicMock()
    thread_stub.join = MagicMock()
    client._background_thread_event_loop = None
    client._background_thread = thread_stub

    # Must complete without raising even though there is no loop to signal.
    client.stop(None, None, None)

    # The thread was joined exactly once.
    thread_stub.join.assert_called_once()

    # The thread reference is cleared once stop() returns.
    assert client._background_thread is None
29 changes: 29 additions & 0 deletions tests_integ/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from strands.tools.mcp.mcp_client import MCPClient
from strands.tools.mcp.mcp_types import MCPTransport
from strands.types.content import Message
from strands.types.exceptions import MCPClientInitializationError
from strands.types.tools import ToolUse


Expand Down Expand Up @@ -268,3 +269,31 @@ def transport_callback() -> MCPTransport:

def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]:
    """Collect the ``toolUse`` payload of every content block that carries one.

    Args:
        messages: Messages whose ``content`` entries are scanned in order.

    Returns:
        The ``toolUse`` values, in the order they are encountered.
    """
    tool_uses = []
    for message in messages:
        for content_block in message["content"]:
            # Blocks without a "toolUse" key are skipped.
            if "toolUse" in content_block:
                tool_uses.append(content_block["toolUse"])
    return tool_uses


def test_mcp_client_timeout_integration():
    """Integration test: a startup timeout must not hang, leak threads, or block reuse."""
    import threading

    from mcp import StdioServerParameters, stdio_client

    def delayed_transport():
        # Sleep past the 2-second startup timeout used for the first attempt.
        time.sleep(4)
        server_params = StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])
        return stdio_client(server_params)

    client = MCPClient(delayed_transport, startup_timeout=2)
    thread_count_before = threading.active_count()

    # First attempt: entering the context has to fail with the timeout error.
    with pytest.raises(MCPClientInitializationError, match="background thread did not start in 2 seconds"):
        with client:
            pass

    # Give cleanup a moment, then confirm the thread count is back to baseline.
    time.sleep(1)
    assert threading.active_count() == thread_count_before

    # Second attempt on the same instance, now with a generous timeout.
    client._startup_timeout = 60
    with client:
        tools = client.list_tools_sync()
        # NOTE(review): len() is never negative, so this only shows that the
        # call returned a sized object; consider asserting on expected tools.
        assert len(tools) >= 0
Loading