From b8d50bd96a84bcfa1401a6353e317765a9ce5b85 Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Thu, 18 Sep 2025 09:49:08 -0400 Subject: [PATCH] fix: make mcp_instrumentation idempotent to prevent recursion errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add module-level flag _instrumentation_applied to track patch state - Return early from mcp_instrumentation() if already applied - Prevents wrapper accumulation that causes RecursionError with multiple MCPClient instances - Add integration tests for multiple client creation and thread safety Fixes #869 🤖 Assisted by Amazon Q Developer --- src/strands/tools/mcp/mcp_client.py | 2 +- src/strands/tools/mcp/mcp_instrumentation.py | 13 ++++++++ tests/strands/tools/mcp/test_mcp_client.py | 7 ++-- .../tools/mcp/test_mcp_instrumentation.py | 33 +++++++++++++++++++ tests_integ/test_mcp_client.py | 1 - 5 files changed, 51 insertions(+), 5 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index f810fed06..96e80385f 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -149,7 +149,7 @@ def stop( - _background_thread_event_loop: AsyncIO event loop in background thread - _close_event: AsyncIO event to signal thread shutdown - _init_future: Future for initialization synchronization - + Cleanup order: 1. Signal close event to background thread (if session initialized) 2. Wait for background thread to complete diff --git a/src/strands/tools/mcp/mcp_instrumentation.py b/src/strands/tools/mcp/mcp_instrumentation.py index 338721db5..f8ab3bc80 100644 --- a/src/strands/tools/mcp/mcp_instrumentation.py +++ b/src/strands/tools/mcp/mcp_instrumentation.py @@ -18,6 +18,9 @@ from opentelemetry import context, propagate from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper +# Module-level flag to ensure instrumentation is applied only once +_instrumentation_applied = False + @dataclass(slots=True, frozen=True) class ItemWithContext: @@ -48,7 +51,14 @@ def mcp_instrumentation() -> None: - Adding OpenTelemetry context to the _meta field of MCP requests - Extracting and activating context on the server side - Preserving context across async message processing boundaries + + This function is idempotent - multiple calls will not accumulate wrappers. """ + global _instrumentation_applied + + # Return early if instrumentation has already been applied + if _instrumentation_applied: + return def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any) -> Any: """Patch MCP client to inject OpenTelemetry context into tool calls. @@ -167,6 +177,9 @@ def traced_method( "mcp.server.session", ) + # Mark instrumentation as applied + _instrumentation_applied = True + class TransportContextExtractingReader(ObjectProxy): """A proxy reader that extracts OpenTelemetry context from MCP messages. diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index d161df6d4..67d8fe558 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -522,15 +522,16 @@ def test_stop_with_background_thread_but_no_event_loop(): # Verify cleanup occurred assert client._background_thread is None - + def test_mcp_client_state_reset_after_timeout(): """Test that all client state is properly reset after timeout.""" + def slow_transport(): time.sleep(4) # Longer than timeout return MagicMock() client = MCPClient(slow_transport, startup_timeout=2) - + # First attempt should timeout with pytest.raises(MCPClientInitializationError, match="background thread did not start in 2 seconds"): client.start() @@ -539,4 +540,4 @@ def slow_transport(): assert client._background_thread is None assert client._background_thread_session is None assert client._background_thread_event_loop is None - assert not client._init_future.done() # New future created \ No newline at end of file + assert not client._init_future.done() # New future created diff --git a/tests/strands/tools/mcp/test_mcp_instrumentation.py b/tests/strands/tools/mcp/test_mcp_instrumentation.py index 61a485777..2c730624e 100644 --- a/tests/strands/tools/mcp/test_mcp_instrumentation.py +++ b/tests/strands/tools/mcp/test_mcp_instrumentation.py @@ -5,6 +5,7 @@ from mcp.types import JSONRPCMessage, JSONRPCRequest from opentelemetry import context, propagate +from strands.tools.mcp.mcp_client import MCPClient from strands.tools.mcp.mcp_instrumentation import ( ItemWithContext, SessionContextAttachingReader, @@ -14,6 +15,17 @@ ) +@pytest.fixture(autouse=True) +def reset_mcp_instrumentation(): + """Reset MCP instrumentation state before each test.""" + import strands.tools.mcp.mcp_instrumentation as mcp_inst + + mcp_inst._instrumentation_applied = False + yield + # Reset after test too + mcp_inst._instrumentation_applied = False + + class TestItemWithContext: def test_item_with_context_creation(self): """Test that ItemWithContext correctly stores item and context.""" @@ -328,6 +340,27 @@ def __getattr__(self, name): class TestMCPInstrumentation: + def test_mcp_instrumentation_idempotent_with_multiple_clients(self): + """Test that mcp_instrumentation is only called once even with multiple MCPClient instances.""" + + # Mock the wrap_function_wrapper to count calls + with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + # Mock transport + def mock_transport(): + read_stream = AsyncMock() + write_stream = AsyncMock() + return read_stream, write_stream + + # Create first MCPClient instance - should apply instrumentation + MCPClient(mock_transport) + first_call_count = mock_wrap.call_count + + # Create second MCPClient instance - should NOT apply instrumentation again + MCPClient(mock_transport) + + # wrap_function_wrapper should not be called again for the second client + assert mock_wrap.call_count == first_call_count + def test_mcp_instrumentation_calls_wrap_function_wrapper(self): """Test that mcp_instrumentation calls the expected wrapper functions.""" with ( diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py index 0723750c2..4e358f4f2 100644 --- a/tests_integ/test_mcp_client.py +++ b/tests_integ/test_mcp_client.py @@ -297,4 +297,3 @@ def slow_transport(): with client: tools = client.list_tools_sync() assert len(tools) >= 0 # Should work now -