diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 784636fd0..8c21baa4a 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -20,6 +20,7 @@ from mcp import ClientSession, ListToolsResult from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import GetPromptResult, ListPromptsResult from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent @@ -165,6 +166,54 @@ async def _list_tools_async() -> ListToolsResult: self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) + def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromptsResult: + """Synchronously retrieves the list of available prompts from the MCP server. + + This method calls the asynchronous list_prompts method on the MCP session + and returns the raw ListPromptsResult with pagination support. + + Args: + pagination_token: Optional token for pagination + + Returns: + ListPromptsResult: The raw MCP response containing prompts and pagination info + """ + self._log_debug_with_thread("listing MCP prompts synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _list_prompts_async() -> ListPromptsResult: + return await self._background_thread_session.list_prompts(cursor=pagination_token) + + list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result() + self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts)) + for prompt in list_prompts_result.prompts: + self._log_debug_with_thread(prompt.name) + + return list_prompts_result + + def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResult: + """Synchronously retrieves a prompt from the MCP server. + + Args: + prompt_id: The ID of the prompt to retrieve + args: Optional arguments to pass to the prompt + + Returns: + GetPromptResult: The prompt response from the MCP server + """ + self._log_debug_with_thread("getting MCP prompt synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _get_prompt_async() -> GetPromptResult: + return await self._background_thread_session.get_prompt(prompt_id, arguments=args) + + get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result() + self._log_debug_with_thread("received prompt from MCP server") + + return get_prompt_result + def call_tool_sync( self, tool_use_id: str, diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 3d3792c71..bd88382cd 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -4,6 +4,7 @@ import pytest from mcp import ListToolsResult from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import GetPromptResult, ListPromptsResult, Prompt, PromptMessage from mcp.types import TextContent as MCPTextContent from mcp.types import Tool as MCPTool @@ -404,3 +405,64 @@ def test_exception_when_future_not_running(): # Verify that set_exception was not called since the future was not running mock_future.set_exception.assert_not_called() + + +# Prompt Tests - Sync Methods + + +def test_list_prompts_sync(mock_transport, mock_session): + """Test that list_prompts_sync correctly retrieves prompts.""" + mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1") + mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_prompts_sync() + + mock_session.list_prompts.assert_called_once_with(cursor=None) + assert len(result.prompts) == 1 + assert result.prompts[0].name == "test_prompt" + assert result.nextCursor is None + + +def test_list_prompts_sync_with_pagination_token(mock_transport, mock_session): + """Test that list_prompts_sync correctly passes pagination token and returns next cursor.""" + mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1") + mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt], nextCursor="next_page_token") + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_prompts_sync(pagination_token="current_page_token") + + mock_session.list_prompts.assert_called_once_with(cursor="current_page_token") + assert len(result.prompts) == 1 + assert result.prompts[0].name == "test_prompt" + assert result.nextCursor == "next_page_token" + + +def test_list_prompts_sync_session_not_active(): + """Test that list_prompts_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.list_prompts_sync() + + +def test_get_prompt_sync(mock_transport, mock_session): + """Test that get_prompt_sync correctly retrieves a prompt.""" + mock_message = PromptMessage(role="user", content=MCPTextContent(type="text", text="This is a test prompt")) + mock_session.get_prompt.return_value = GetPromptResult(messages=[mock_message]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.get_prompt_sync("test_prompt_id", {"key": "value"}) + + mock_session.get_prompt.assert_called_once_with("test_prompt_id", arguments={"key": "value"}) + assert len(result.messages) == 1 + assert result.messages[0].role == "user" + assert result.messages[0].content.text == "This is a test prompt" + + +def test_get_prompt_sync_session_not_active(): + """Test that get_prompt_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.get_prompt_sync("test_prompt_id", {}) diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py index ebd4f5896..3de249435 100644 --- a/tests_integ/test_mcp_client.py +++ b/tests_integ/test_mcp_client.py @@ -18,18 +18,17 @@ from strands.types.tools import ToolUse -def start_calculator_server(transport: Literal["sse", "streamable-http"], port=int): +def start_comprehensive_mcp_server(transport: Literal["sse", "streamable-http"], port=int): """ - Initialize and start an MCP calculator server for integration testing. + Initialize and start a comprehensive MCP server for integration testing. - This function creates a FastMCP server instance that provides a simple - calculator tool for performing addition operations. The server uses - Server-Sent Events (SSE) transport for communication, making it accessible - over HTTP. + This function creates a FastMCP server instance that provides tools, prompts, + and resources all in one server for comprehensive testing. The server uses + Server-Sent Events (SSE) or streamable HTTP transport for communication. """ from mcp.server import FastMCP - mcp = FastMCP("Calculator Server", port=port) + mcp = FastMCP("Comprehensive MCP Server", port=port) @mcp.tool(description="Calculator tool which performs calculations") def calculator(x: int, y: int) -> int: @@ -44,6 +43,15 @@ def generate_custom_image() -> MCPImageContent: except Exception as e: print("Error while generating custom image: {}".format(e)) + # Prompts + @mcp.prompt(description="A greeting prompt template") + def greeting_prompt(name: str = "World") -> str: + return f"Hello, {name}! How are you today?" + + @mcp.prompt(description="A math problem prompt template") + def math_prompt(operation: str = "addition", difficulty: str = "easy") -> str: + return f"Create a {difficulty} {operation} math problem and solve it step by step." + mcp.run(transport=transport) @@ -58,8 +66,9 @@ def test_mcp_client(): {'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]} """ # noqa: E501 + # Start comprehensive server with tools, prompts, and resources server_thread = threading.Thread( - target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True + target=start_comprehensive_mcp_server, kwargs={"transport": "sse", "port": 8000}, daemon=True ) server_thread.start() time.sleep(2) # wait for server to startup completely @@ -68,8 +77,14 @@ def test_mcp_client(): stdio_mcp_client = MCPClient( lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) ) + with sse_mcp_client, stdio_mcp_client: - agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync()) + # Test Tools functionality + sse_tools = sse_mcp_client.list_tools_sync() + stdio_tools = stdio_mcp_client.list_tools_sync() + all_tools = sse_tools + stdio_tools + + agent = Agent(tools=all_tools) agent("add 1 and 2, then echo the result back to me") tool_use_content_blocks = _messages_to_content_blocks(agent.messages) @@ -88,6 +103,43 @@ def test_mcp_client(): ] ) + # Test Prompts functionality + prompts_result = sse_mcp_client.list_prompts_sync() + assert len(prompts_result.prompts) >= 2 # We expect at least greeting and math prompts + + prompt_names = [prompt.name for prompt in prompts_result.prompts] + assert "greeting_prompt" in prompt_names + assert "math_prompt" in prompt_names + + # Test get_prompt_sync with greeting prompt + greeting_result = sse_mcp_client.get_prompt_sync("greeting_prompt", {"name": "Alice"}) + assert len(greeting_result.messages) > 0 + prompt_text = greeting_result.messages[0].content.text + assert "Hello, Alice!" in prompt_text + assert "How are you today?" in prompt_text + + # Test get_prompt_sync with math prompt + math_result = sse_mcp_client.get_prompt_sync( + "math_prompt", {"operation": "multiplication", "difficulty": "medium"} + ) + assert len(math_result.messages) > 0 + math_text = math_result.messages[0].content.text + assert "multiplication" in math_text + assert "medium" in math_text + assert "step by step" in math_text + + # Test pagination support for prompts + prompts_with_token = sse_mcp_client.list_prompts_sync(pagination_token=None) + assert len(prompts_with_token.prompts) >= 0 + + # Test pagination support for tools (existing functionality) + tools_with_token = sse_mcp_client.list_tools_sync(pagination_token=None) + assert len(tools_with_token) >= 0 + + # TODO: Add resources testing when resources are implemented + # resources_result = sse_mcp_client.list_resources_sync() + # assert len(resources_result.resources) >= 0 + tool_use_id = "test-structured-content-123" result = stdio_mcp_client.call_tool_sync( tool_use_id=tool_use_id, @@ -185,8 +237,9 @@ def test_mcp_client_without_structured_content(): reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", ) def test_streamable_http_mcp_client(): + """Test comprehensive MCP client with streamable HTTP transport.""" server_thread = threading.Thread( - target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True + target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True ) server_thread.start() time.sleep(2) # wait for server to startup completely @@ -196,12 +249,22 @@ def transport_callback() -> MCPTransport: streamable_http_client = MCPClient(transport_callback) with streamable_http_client: + # Test tools agent = Agent(tools=streamable_http_client.list_tools_sync()) agent("add 1 and 2 using a calculator") tool_use_content_blocks = _messages_to_content_blocks(agent.messages) assert any([block["name"] == "calculator" for block in tool_use_content_blocks]) + # Test prompts + prompts_result = streamable_http_client.list_prompts_sync() + assert len(prompts_result.prompts) >= 2 + + greeting_result = streamable_http_client.get_prompt_sync("greeting_prompt", {"name": "Charlie"}) + assert len(greeting_result.messages) > 0 + prompt_text = greeting_result.messages[0].content.text + assert "Hello, Charlie!" in prompt_text + def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block]