diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 402005604..f810fed06 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -143,6 +143,18 @@ def stop( This method is defensive and can handle partial initialization states that may occur if start() fails partway through initialization. + Resources to cleanup: + - _background_thread: Thread running the async event loop + - _background_thread_session: MCP ClientSession (auto-closed by context manager) + - _background_thread_event_loop: AsyncIO event loop in background thread + - _close_event: AsyncIO event to signal thread shutdown + - _init_future: Future for initialization synchronization + + Cleanup order: + 1. Signal close event to background thread (if session initialized) + 2. Wait for background thread to complete + 3. Reset all state for reuse + Args: exc_type: Exception type if an exception was raised in the context exc_val: Exception value if an exception was raised in the context @@ -158,7 +170,9 @@ def stop( async def _set_close_event() -> None: self._close_event.set() - self._invoke_on_background_thread(_set_close_event()).result() + # Not calling _invoke_on_background_thread since the session does not need to exist + # we only need the thread and event loop to exist. + asyncio.run_coroutine_threadsafe(coro=_set_close_event(), loop=self._background_thread_event_loop) self._log_debug_with_thread("waiting for background thread to join") self._background_thread.join() @@ -168,6 +182,8 @@ async def _set_close_event() -> None: self._init_future = futures.Future() self._close_event = asyncio.Event() self._background_thread = None + self._background_thread_session = None + self._background_thread_event_loop = None self._session_id = uuid.uuid4() def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedList[MCPAgentTool]: diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 8514a67d4..d161df6d4 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -521,3 +521,22 @@ def test_stop_with_background_thread_but_no_event_loop(): # Verify cleanup occurred assert client._background_thread is None + + +def test_mcp_client_state_reset_after_timeout(): + """Test that all client state is properly reset after timeout.""" + def slow_transport(): + time.sleep(4) # Longer than timeout + return MagicMock() + + client = MCPClient(slow_transport, startup_timeout=2) + + # First attempt should timeout + with pytest.raises(MCPClientInitializationError, match="background thread did not start in 2 seconds"): + client.start() + + # Verify all state is reset + assert client._background_thread is None + assert client._background_thread_session is None + assert client._background_thread_event_loop is None + assert not client._init_future.done() # New future created \ No newline at end of file diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py index 4e358f4f2..0723750c2 100644 --- a/tests_integ/test_mcp_client.py +++ b/tests_integ/test_mcp_client.py @@ -297,3 +297,4 @@ def slow_transport(): with client: tools = client.list_tools_sync() assert len(tools) >= 0 # Should work now +