Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def stop(
- _background_thread_event_loop: AsyncIO event loop in background thread
- _close_event: AsyncIO event to signal thread shutdown
- _init_future: Future for initialization synchronization

Cleanup order:
1. Signal close event to background thread (if session initialized)
2. Wait for background thread to complete
Expand Down
13 changes: 13 additions & 0 deletions src/strands/tools/mcp/mcp_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from opentelemetry import context, propagate
from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper

# Module-level flag to ensure instrumentation is applied only once
_instrumentation_applied = False


@dataclass(slots=True, frozen=True)
class ItemWithContext:
Expand Down Expand Up @@ -48,7 +51,14 @@ def mcp_instrumentation() -> None:
- Adding OpenTelemetry context to the _meta field of MCP requests
- Extracting and activating context on the server side
- Preserving context across async message processing boundaries

This function is idempotent - multiple calls will not accumulate wrappers.
"""
global _instrumentation_applied

# Return early if instrumentation has already been applied
if _instrumentation_applied:
return

def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any) -> Any:
"""Patch MCP client to inject OpenTelemetry context into tool calls.
Expand Down Expand Up @@ -167,6 +177,9 @@ def traced_method(
"mcp.server.session",
)

# Mark instrumentation as applied
_instrumentation_applied = True


class TransportContextExtractingReader(ObjectProxy):
"""A proxy reader that extracts OpenTelemetry context from MCP messages.
Expand Down
7 changes: 4 additions & 3 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,15 +522,16 @@ def test_stop_with_background_thread_but_no_event_loop():
# Verify cleanup occurred
assert client._background_thread is None


def test_mcp_client_state_reset_after_timeout():
"""Test that all client state is properly reset after timeout."""

def slow_transport():
time.sleep(4) # Longer than timeout
return MagicMock()

client = MCPClient(slow_transport, startup_timeout=2)

# First attempt should timeout
with pytest.raises(MCPClientInitializationError, match="background thread did not start in 2 seconds"):
client.start()
Expand All @@ -539,4 +540,4 @@ def slow_transport():
assert client._background_thread is None
assert client._background_thread_session is None
assert client._background_thread_event_loop is None
assert not client._init_future.done() # New future created
assert not client._init_future.done() # New future created
33 changes: 33 additions & 0 deletions tests/strands/tools/mcp/test_mcp_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from mcp.types import JSONRPCMessage, JSONRPCRequest
from opentelemetry import context, propagate

from strands.tools.mcp.mcp_client import MCPClient
from strands.tools.mcp.mcp_instrumentation import (
ItemWithContext,
SessionContextAttachingReader,
Expand All @@ -14,6 +15,17 @@
)


@pytest.fixture(autouse=True)
def reset_mcp_instrumentation():
"""Reset MCP instrumentation state before each test."""
import strands.tools.mcp.mcp_instrumentation as mcp_inst

mcp_inst._instrumentation_applied = False
yield
# Reset after test too
mcp_inst._instrumentation_applied = False


class TestItemWithContext:
def test_item_with_context_creation(self):
"""Test that ItemWithContext correctly stores item and context."""
Expand Down Expand Up @@ -328,6 +340,27 @@ def __getattr__(self, name):


class TestMCPInstrumentation:
def test_mcp_instrumentation_idempotent_with_multiple_clients(self):
"""Test that mcp_instrumentation is only called once even with multiple MCPClient instances."""

# Mock the wrap_function_wrapper to count calls
with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap:
# Mock transport
def mock_transport():
read_stream = AsyncMock()
write_stream = AsyncMock()
return read_stream, write_stream

# Create first MCPClient instance - should apply instrumentation
MCPClient(mock_transport)
first_call_count = mock_wrap.call_count

# Create second MCPClient instance - should NOT apply instrumentation again
MCPClient(mock_transport)

# wrap_function_wrapper should not be called again for the second client
assert mock_wrap.call_count == first_call_count

def test_mcp_instrumentation_calls_wrap_function_wrapper(self):
"""Test that mcp_instrumentation calls the expected wrapper functions."""
with (
Expand Down
1 change: 0 additions & 1 deletion tests_integ/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,4 +297,3 @@ def slow_transport():
with client:
tools = client.list_tools_sync()
assert len(tools) >= 0 # Should work now

Loading