From caeda427793fd466670b9bdaea40aa944eac8053 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 20 Aug 2025 14:11:15 +0100 Subject: [PATCH 01/11] feat: add paginated list decorators for prompts, resources, and tools MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add list_prompts_paginated, list_resources_paginated, and list_tools_paginated decorators to support cursor-based pagination for listing endpoints. These decorators: - Accept a cursor parameter (can be None for first page) - Return the respective ListResult type directly - Maintain backward compatibility with existing non-paginated decorators - Update tool cache for list_tools_paginated Also includes simplified unit tests that verify cursor passthrough. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/mcp/server/lowlevel/server.py | 45 +++++++ tests/server/lowlevel/__init__.py | 0 .../server/lowlevel/test_server_pagination.py | 110 ++++++++++++++++++ 3 files changed, 155 insertions(+) create mode 100644 tests/server/lowlevel/__init__.py create mode 100644 tests/server/lowlevel/test_server_pagination.py diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 3076e283e..c98056b86 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -241,6 +241,19 @@ async def handler(_: Any): return decorator + def list_prompts_paginated(self): + def decorator(func: Callable[[types.Cursor | None], Awaitable[types.ListPromptsResult]]): + logger.debug("Registering handler for PromptListRequest with pagination") + + async def handler(req: types.ListPromptsRequest): + result = await func(req.params.cursor if req.params else None) + return types.ServerResult(result) + + self.request_handlers[types.ListPromptsRequest] = handler + return func + + return decorator + def get_prompt(self): def decorator( func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]], @@ -269,6 +282,19 @@ async def handler(_: Any): return decorator + def list_resources_paginated(self): + def decorator(func: Callable[[types.Cursor | None], Awaitable[types.ListResourcesResult]]): + logger.debug("Registering handler for ListResourcesRequest with pagination") + + async def handler(req: types.ListResourcesRequest): + result = await func(req.params.cursor if req.params else None) + return types.ServerResult(result) + + self.request_handlers[types.ListResourcesRequest] = handler + return func + + return decorator + def list_resource_templates(self): def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): logger.debug("Registering handler for ListResourceTemplatesRequest") @@ -396,6 +422,25 @@ async def handler(_: Any): return decorator + def list_tools_paginated(self): + def decorator( + func: Callable[[types.Cursor | None], Awaitable[types.ListToolsResult]] + ): + logger.debug("Registering paginated handler for ListToolsRequest") + + async def handler(request: types.ListToolsRequest): + cursor = request.params.cursor if request.params else None + result = await func(cursor) + # Refresh the tool cache with returned tools + for tool in result.tools: + self._tool_cache[tool.name] = tool + return types.ServerResult(result) + + self.request_handlers[types.ListToolsRequest] = handler + return func + + return decorator + def _make_error_result(self, error_message: str) -> types.ServerResult: """Create a ServerResult with an error CallToolResult.""" return types.ServerResult( diff --git a/tests/server/lowlevel/__init__.py b/tests/server/lowlevel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/lowlevel/test_server_pagination.py b/tests/server/lowlevel/test_server_pagination.py new file mode 100644 index 000000000..d094fd432 --- /dev/null +++ b/tests/server/lowlevel/test_server_pagination.py @@ -0,0 +1,110 @@ +import pytest + +from mcp.server import Server +from mcp.types import ( + Cursor, + ListPromptsRequest, + ListPromptsResult, + ListResourcesRequest, + ListResourcesResult, + ListToolsRequest, + ListToolsResult, + PaginatedRequestParams, + ServerResult, +) + + +@pytest.mark.anyio +async def test_list_prompts_pagination() -> None: + server = Server("test") + test_cursor = "test-cursor-123" + + # Track what cursor was received + received_cursor: Cursor | None = None + + @server.list_prompts_paginated() + async def handle_list_prompts(cursor: Cursor | None) -> ListPromptsResult: + nonlocal received_cursor + received_cursor = cursor + return ListPromptsResult(prompts=[], nextCursor="next") + + handler = server.request_handlers[ListPromptsRequest] + + # Test: No cursor provided -> handler receives None + request = ListPromptsRequest(method="prompts/list", params=None) + result = await handler(request) + assert received_cursor is None + assert isinstance(result, ServerResult) + + # Test: Cursor provided -> handler receives exact cursor value + request_with_cursor = ListPromptsRequest( + method="prompts/list", + params=PaginatedRequestParams(cursor=test_cursor) + ) + result2 = await handler(request_with_cursor) + assert received_cursor == test_cursor + assert isinstance(result2, ServerResult) + + +@pytest.mark.anyio +async def test_list_resources_pagination() -> None: + server = Server("test") + test_cursor = "resource-cursor-456" + + # Track what cursor was received + received_cursor: Cursor | None = None + + @server.list_resources_paginated() + async def handle_list_resources(cursor: Cursor | None) -> ListResourcesResult: + nonlocal received_cursor + received_cursor = cursor + return ListResourcesResult(resources=[], nextCursor="next") + + handler = server.request_handlers[ListResourcesRequest] + + # Test: No cursor provided -> handler receives None + request = ListResourcesRequest(method="resources/list", params=None) + result = await handler(request) + assert received_cursor is None + assert isinstance(result, ServerResult) + + # Test: Cursor provided -> handler receives exact cursor value + request_with_cursor = ListResourcesRequest( + method="resources/list", + params=PaginatedRequestParams(cursor=test_cursor) + ) + result2 = await handler(request_with_cursor) + assert received_cursor == test_cursor + assert isinstance(result2, ServerResult) + + +@pytest.mark.anyio +async def test_list_tools_pagination() -> None: + server = Server("test") + test_cursor = "tools-cursor-789" + + # Track what cursor was received + received_cursor: Cursor | None = None + + @server.list_tools_paginated() + async def handle_list_tools(cursor: Cursor | None) -> ListToolsResult: + nonlocal received_cursor + received_cursor = cursor + return ListToolsResult(tools=[], nextCursor="next") + + handler = server.request_handlers[ListToolsRequest] + + # Test: No cursor provided -> handler receives None + request = ListToolsRequest(method="tools/list", params=None) + result = await handler(request) + assert received_cursor is None + assert isinstance(result, ServerResult) + + # Test: Cursor provided -> handler receives exact cursor value + request_with_cursor = ListToolsRequest( + method="tools/list", + params=PaginatedRequestParams(cursor=test_cursor) + ) + result2 = await handler(request_with_cursor) + assert received_cursor == test_cursor + assert isinstance(result2, ServerResult) From 62cfab1c63429f6d92d8c370fd6861b86c7f7939 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 20 Aug 2025 14:42:41 +0100 Subject: [PATCH 02/11] style: apply ruff formatting to pass pre-commit checks Apply automatic formatting from ruff to ensure code meets project standards: - Remove trailing whitespace - Adjust line breaks for consistency - Format function arguments according to line length limits --- src/mcp/server/lowlevel/server.py | 4 +- .../server/lowlevel/test_server_pagination.py | 43 ++++++++----------- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index c98056b86..52d116612 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -423,9 +423,7 @@ async def handler(_: Any): return decorator def list_tools_paginated(self): - def decorator( - func: Callable[[types.Cursor | None], Awaitable[types.ListToolsResult]] - ): + def decorator(func: Callable[[types.Cursor | None], Awaitable[types.ListToolsResult]]): logger.debug("Registering paginated handler for ListToolsRequest") async def handler(request: types.ListToolsRequest): diff --git a/tests/server/lowlevel/test_server_pagination.py b/tests/server/lowlevel/test_server_pagination.py index d094fd432..f2c786e45 100644 --- a/tests/server/lowlevel/test_server_pagination.py +++ b/tests/server/lowlevel/test_server_pagination.py @@ -18,29 +18,26 @@ async def test_list_prompts_pagination() -> None: server = Server("test") test_cursor = "test-cursor-123" - + # Track what cursor was received received_cursor: Cursor | None = None - + @server.list_prompts_paginated() async def handle_list_prompts(cursor: Cursor | None) -> ListPromptsResult: nonlocal received_cursor received_cursor = cursor return ListPromptsResult(prompts=[], nextCursor="next") - + handler = server.request_handlers[ListPromptsRequest] - + # Test: No cursor provided -> handler receives None request = ListPromptsRequest(method="prompts/list", params=None) result = await handler(request) assert received_cursor is None assert isinstance(result, ServerResult) - + # Test: Cursor provided -> handler receives exact cursor value - request_with_cursor = ListPromptsRequest( - method="prompts/list", - params=PaginatedRequestParams(cursor=test_cursor) - ) + request_with_cursor = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor=test_cursor)) result2 = await handler(request_with_cursor) assert received_cursor == test_cursor assert isinstance(result2, ServerResult) @@ -50,28 +47,27 @@ async def handle_list_prompts(cursor: Cursor | None) -> ListPromptsResult: async def test_list_resources_pagination() -> None: server = Server("test") test_cursor = "resource-cursor-456" - + # Track what cursor was received received_cursor: Cursor | None = None - + @server.list_resources_paginated() async def handle_list_resources(cursor: Cursor | None) -> ListResourcesResult: nonlocal received_cursor received_cursor = cursor return ListResourcesResult(resources=[], nextCursor="next") - + handler = server.request_handlers[ListResourcesRequest] - + # Test: No cursor provided -> handler receives None request = ListResourcesRequest(method="resources/list", params=None) result = await handler(request) assert received_cursor is None assert isinstance(result, ServerResult) - + # Test: Cursor provided -> handler receives exact cursor value request_with_cursor = ListResourcesRequest( - method="resources/list", - params=PaginatedRequestParams(cursor=test_cursor) + method="resources/list", params=PaginatedRequestParams(cursor=test_cursor) ) result2 = await handler(request_with_cursor) assert received_cursor == test_cursor @@ -82,29 +78,26 @@ async def handle_list_resources(cursor: Cursor | None) -> ListResourcesResult: async def test_list_tools_pagination() -> None: server = Server("test") test_cursor = "tools-cursor-789" - + # Track what cursor was received received_cursor: Cursor | None = None - + @server.list_tools_paginated() async def handle_list_tools(cursor: Cursor | None) -> ListToolsResult: nonlocal received_cursor received_cursor = cursor return ListToolsResult(tools=[], nextCursor="next") - + handler = server.request_handlers[ListToolsRequest] - + # Test: No cursor provided -> handler receives None request = ListToolsRequest(method="tools/list", params=None) result = await handler(request) assert received_cursor is None assert isinstance(result, ServerResult) - + # Test: Cursor provided -> handler receives exact cursor value - request_with_cursor = ListToolsRequest( - method="tools/list", - params=PaginatedRequestParams(cursor=test_cursor) - ) + request_with_cursor = ListToolsRequest(method="tools/list", params=PaginatedRequestParams(cursor=test_cursor)) result2 = await handler(request_with_cursor) assert received_cursor == test_cursor assert isinstance(result2, ServerResult) From b7e6a0c7b6cd8b334205c4fa52ccedfef2fb79e6 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 20 Aug 2025 16:25:24 +0100 Subject: [PATCH 03/11] feat: add pagination examples and documentation - Create mcp_simple_pagination example server demonstrating all three paginated endpoints - Add pagination snippets for both server and client implementations - Update README to use snippet-source pattern for pagination examples - Move mutually exclusive note to blockquote format for better visibility - Complete example shows tools, resources, and prompts pagination with different page sizes --- README.md | 116 +++++++++ examples/servers/simple-pagination/README.md | 77 ++++++ .../mcp_simple_pagination/__init__.py | 0 .../mcp_simple_pagination/__main__.py | 5 + .../mcp_simple_pagination/server.py | 225 ++++++++++++++++++ .../servers/simple-pagination/pyproject.toml | 47 ++++ .../snippets/clients/pagination_client.py | 41 ++++ .../snippets/servers/pagination_example.py | 35 +++ uv.lock | 36 ++- 9 files changed, 581 insertions(+), 1 deletion(-) create mode 100644 examples/servers/simple-pagination/README.md create mode 100644 examples/servers/simple-pagination/mcp_simple_pagination/__init__.py create mode 100644 examples/servers/simple-pagination/mcp_simple_pagination/__main__.py create mode 100644 examples/servers/simple-pagination/mcp_simple_pagination/server.py create mode 100644 examples/servers/simple-pagination/pyproject.toml create mode 100644 examples/snippets/clients/pagination_client.py create mode 100644 examples/snippets/servers/pagination_example.py diff --git a/README.md b/README.md index d2fb9194a..fda988898 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ - [Advanced Usage](#advanced-usage) - [Low-Level Server](#low-level-server) - [Structured Output Support](#structured-output-support) + - [Pagination (Advanced)](#pagination-advanced) - [Writing MCP Clients](#writing-mcp-clients) - [Client Display Utilities](#client-display-utilities) - [OAuth Authentication for Clients](#oauth-authentication-for-clients) @@ -1737,6 +1738,121 @@ Tools can return data in three ways: When an `outputSchema` is defined, the server automatically validates the structured output against the schema. This ensures type safety and helps catch errors early. +### Pagination (Advanced) + +For servers that need to handle large datasets, the low-level server provides paginated versions of list operations. This is an optional optimization - most servers won't need pagination unless they're dealing with hundreds or thousands of items. + +#### Server-side Implementation + + +```python +""" +Example of implementing pagination with MCP server decorators. +""" + +from pydantic import AnyUrl + +import mcp.types as types +from mcp.server.lowlevel import Server + +# Initialize the server +server = Server("paginated-server") + +# Sample data to paginate +ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items + + +@server.list_resources_paginated() +async def list_resources_paginated(cursor: types.Cursor | None) -> types.ListResourcesResult: + """List resources with pagination support.""" + page_size = 10 + + # Parse cursor to get offset + start = 0 if cursor is None else int(cursor) + end = start + page_size + + # Get page of resources + page_items = [ + types.Resource(uri=AnyUrl(f"resource://items/{item}"), name=item, description=f"Description for {item}") + for item in ITEMS[start:end] + ] + + # Determine next cursor + next_cursor = str(end) if end < len(ITEMS) else None + + return types.ListResourcesResult(resources=page_items, nextCursor=next_cursor) +``` + +_Full example: [examples/snippets/servers/pagination_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/pagination_example.py)_ + + +Similar decorators are available for all list operations: + +- `@server.list_tools_paginated()` - for paginating tools +- `@server.list_resources_paginated()` - for paginating resources +- `@server.list_prompts_paginated()` - for paginating prompts + +#### Client-side Consumption + + +```python +""" +Example of consuming paginated MCP endpoints from a client. +""" + +import asyncio + +from mcp.client.session import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.types import Resource + + +async def list_all_resources() -> None: + """Fetch all resources using pagination.""" + async with stdio_client(StdioServerParameters(command="uv", args=["run", "mcp-simple-pagination"])) as ( + read, + write, + ): + async with ClientSession(read, write) as session: + await session.initialize() + + all_resources: list[Resource] = [] + cursor = None + + while True: + # Fetch a page of resources + result = await session.list_resources(cursor=cursor) + all_resources.extend(result.resources) + + print(f"Fetched {len(result.resources)} resources") + + # Check if there are more pages + if result.nextCursor: + cursor = result.nextCursor + else: + break + + print(f"Total resources: {len(all_resources)}") + + +if __name__ == "__main__": + asyncio.run(list_all_resources()) +``` + +_Full example: [examples/snippets/clients/pagination_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/pagination_client.py)_ + + +#### Key Points + +- **Cursors are opaque strings** - the server defines the format (numeric offsets, timestamps, etc.) +- **Return `nextCursor=None`** when there are no more pages +- **Backward compatible** - clients that don't support pagination will still work (they'll just get the first page) +- **Flexible page sizes** - Each endpoint can define its own page size based on data characteristics + +> **NOTE**: The paginated decorators (`list_tools_paginated()`, `list_resources_paginated()`, `list_prompts_paginated()`) are mutually exclusive with their non-paginated counterparts and cannot be used together on the same server instance. + +See the [simple-pagination example](examples/servers/simple-pagination) for a complete implementation. + ### Writing MCP Clients The SDK provides a high-level client interface for connecting to MCP servers using various [transports](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports): diff --git a/examples/servers/simple-pagination/README.md b/examples/servers/simple-pagination/README.md new file mode 100644 index 000000000..e732b8efb --- /dev/null +++ b/examples/servers/simple-pagination/README.md @@ -0,0 +1,77 @@ +# MCP Simple Pagination + +A simple MCP server demonstrating pagination for tools, resources, and prompts using cursor-based pagination. + +## Usage + +Start the server using either stdio (default) or SSE transport: + +```bash +# Using stdio transport (default) +uv run mcp-simple-pagination + +# Using SSE transport on custom port +uv run mcp-simple-pagination --transport sse --port 8000 +``` + +The server exposes: + +- 25 tools (paginated, 5 per page) +- 30 resources (paginated, 10 per page) +- 20 prompts (paginated, 7 per page) + +Each paginated list returns a `nextCursor` when more pages are available. Use this cursor in subsequent requests to retrieve the next page. + +## Example + +Using the MCP client, you can retrieve paginated items like this using the STDIO transport: + +```python +import asyncio +from mcp.client.session import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client + + +async def main(): + async with stdio_client( + StdioServerParameters(command="uv", args=["run", "mcp-simple-pagination"]) + ) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # Get first page of tools + tools_page1 = await session.list_tools() + print(f"First page: {len(tools_page1.tools)} tools") + print(f"Next cursor: {tools_page1.nextCursor}") + + # Get second page using cursor + if tools_page1.nextCursor: + tools_page2 = await session.list_tools(cursor=tools_page1.nextCursor) + print(f"Second page: {len(tools_page2.tools)} tools") + + # Similarly for resources + resources_page1 = await session.list_resources() + print(f"First page: {len(resources_page1.resources)} resources") + + # And for prompts + prompts_page1 = await session.list_prompts() + print(f"First page: {len(prompts_page1.prompts)} prompts") + + +asyncio.run(main()) +``` + +## Pagination Details + +The server uses simple numeric indices as cursors for demonstration purposes. In production scenarios, you might use: + +- Database offsets or row IDs +- Timestamps for time-based pagination +- Opaque tokens encoding pagination state + +The pagination implementation demonstrates: + +- Handling `None` cursor for the first page +- Returning `nextCursor` when more data exists +- Gracefully handling invalid cursors +- Different page sizes for different resource types diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/__init__.py b/examples/servers/simple-pagination/mcp_simple_pagination/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/__main__.py b/examples/servers/simple-pagination/mcp_simple_pagination/__main__.py new file mode 100644 index 000000000..e7ef16530 --- /dev/null +++ b/examples/servers/simple-pagination/mcp_simple_pagination/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .server import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/server.py b/examples/servers/simple-pagination/mcp_simple_pagination/server.py new file mode 100644 index 000000000..8a7ebdb9c --- /dev/null +++ b/examples/servers/simple-pagination/mcp_simple_pagination/server.py @@ -0,0 +1,225 @@ +""" +Simple MCP server demonstrating pagination for tools, resources, and prompts. + +This example shows how to use the paginated decorators to handle large lists +of items that need to be split across multiple pages. +""" + +from typing import Any + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from pydantic import AnyUrl +from starlette.requests import Request + +# Sample data - in real scenarios, this might come from a database +SAMPLE_TOOLS = [ + types.Tool( + name=f"tool_{i}", + title=f"Tool {i}", + description=f"This is sample tool number {i}", + inputSchema={"type": "object", "properties": {"input": {"type": "string"}}}, + ) + for i in range(1, 26) # 25 tools total +] + +SAMPLE_RESOURCES = [ + types.Resource( + uri=AnyUrl(f"file:///path/to/resource_{i}.txt"), + name=f"resource_{i}", + description=f"This is sample resource number {i}", + ) + for i in range(1, 31) # 30 resources total +] + +SAMPLE_PROMPTS = [ + types.Prompt( + name=f"prompt_{i}", + description=f"This is sample prompt number {i}", + arguments=[ + types.PromptArgument(name="arg1", description="First argument", required=True), + ], + ) + for i in range(1, 21) # 20 prompts total +] + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option( + "--transport", + type=click.Choice(["stdio", "sse"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server("mcp-simple-pagination") + + # Paginated list_tools - returns 5 tools per page + @app.list_tools_paginated() + async def list_tools_paginated(cursor: types.Cursor | None) -> types.ListToolsResult: + page_size = 5 + + if cursor is None: + # First page + start_idx = 0 + else: + # Parse cursor to get the start index + try: + start_idx = int(cursor) + except (ValueError, TypeError): + # Invalid cursor, return empty + return types.ListToolsResult(tools=[], nextCursor=None) + + # Get the page of tools + page_tools = SAMPLE_TOOLS[start_idx : start_idx + page_size] + + # Determine if there are more pages + next_cursor = None + if start_idx + page_size < len(SAMPLE_TOOLS): + next_cursor = str(start_idx + page_size) + + return types.ListToolsResult(tools=page_tools, nextCursor=next_cursor) + + # Paginated list_resources - returns 10 resources per page + @app.list_resources_paginated() + async def list_resources_paginated( + cursor: types.Cursor | None, + ) -> types.ListResourcesResult: + page_size = 10 + + if cursor is None: + # First page + start_idx = 0 + else: + # Parse cursor to get the start index + try: + start_idx = int(cursor) + except (ValueError, TypeError): + # Invalid cursor, return empty + return types.ListResourcesResult(resources=[], nextCursor=None) + + # Get the page of resources + page_resources = SAMPLE_RESOURCES[start_idx : start_idx + page_size] + + # Determine if there are more pages + next_cursor = None + if start_idx + page_size < len(SAMPLE_RESOURCES): + next_cursor = str(start_idx + page_size) + + return types.ListResourcesResult(resources=page_resources, nextCursor=next_cursor) + + # Paginated list_prompts - returns 7 prompts per page + @app.list_prompts_paginated() + async def list_prompts_paginated( + cursor: types.Cursor | None, + ) -> types.ListPromptsResult: + page_size = 7 + + if cursor is None: + # First page + start_idx = 0 + else: + # Parse cursor to get the start index + try: + start_idx = int(cursor) + except (ValueError, TypeError): + # Invalid cursor, return empty + return types.ListPromptsResult(prompts=[], nextCursor=None) + + # Get the page of prompts + page_prompts = SAMPLE_PROMPTS[start_idx : start_idx + page_size] + + # Determine if there are more pages + next_cursor = None + if start_idx + page_size < len(SAMPLE_PROMPTS): + next_cursor = str(start_idx + page_size) + + return types.ListPromptsResult(prompts=page_prompts, nextCursor=next_cursor) + + # Implement call_tool handler + @app.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + # Find the tool in our sample data + tool = next((t for t in SAMPLE_TOOLS if t.name == name), None) + if not tool: + raise ValueError(f"Unknown tool: {name}") + + # Simple mock response + return [ + types.TextContent( + type="text", + text=f"Called tool '{name}' with arguments: {arguments}", + ) + ] + + # Implement read_resource handler + @app.read_resource() + async def read_resource(uri: AnyUrl) -> str: + # Find the resource in our sample data + resource = next((r for r in SAMPLE_RESOURCES if r.uri == uri), None) + if not resource: + raise ValueError(f"Unknown resource: {uri}") + + # Return a simple string - the decorator will convert it to TextResourceContents + return f"Content of {resource.name}: This is sample content for the resource." + + # Implement get_prompt handler + @app.get_prompt() + async def get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: + # Find the prompt in our sample data + prompt = next((p for p in SAMPLE_PROMPTS if p.name == name), None) + if not prompt: + raise ValueError(f"Unknown prompt: {name}") + + # Simple mock response + message_text = f"This is the prompt '{name}'" + if arguments: + message_text += f" with arguments: {arguments}" + + return types.GetPromptResult( + description=prompt.description, + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text=message_text), + ) + ], + ) + + if transport == "sse": + from mcp.server.sse import SseServerTransport + from starlette.applications import Starlette + from starlette.responses import Response + from starlette.routing import Mount, Route + + sse = SseServerTransport("/messages/") + + async def handle_sse(request: Request): + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] + await app.run(streams[0], streams[1], app.create_initialization_options()) + return Response() + + starlette_app = Starlette( + debug=True, + routes=[ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + + import uvicorn + + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + else: + from mcp.server.stdio import stdio_server + + async def arun(): + async with stdio_server() as streams: + await app.run(streams[0], streams[1], app.create_initialization_options()) + + anyio.run(arun) + + return 0 diff --git a/examples/servers/simple-pagination/pyproject.toml b/examples/servers/simple-pagination/pyproject.toml new file mode 100644 index 000000000..0c60cf73c --- /dev/null +++ b/examples/servers/simple-pagination/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "mcp-simple-pagination" +version = "0.1.0" +description = "A simple MCP server demonstrating pagination for tools, resources, and prompts" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +maintainers = [ + { name = "David Soria Parra", email = "davidsp@anthropic.com" }, + { name = "Justin Spahr-Summers", email = "justin@anthropic.com" }, +] +keywords = ["mcp", "llm", "automation", "pagination", "cursor"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["anyio>=4.5", "click>=8.2.0", "httpx>=0.27", "mcp"] + +[project.scripts] +mcp-simple-pagination = "mcp_simple_pagination.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_pagination"] + +[tool.pyright] +include = ["mcp_simple_pagination"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[tool.uv] +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file diff --git a/examples/snippets/clients/pagination_client.py b/examples/snippets/clients/pagination_client.py new file mode 100644 index 000000000..4df1aec60 --- /dev/null +++ b/examples/snippets/clients/pagination_client.py @@ -0,0 +1,41 @@ +""" +Example of consuming paginated MCP endpoints from a client. +""" + +import asyncio + +from mcp.client.session import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.types import Resource + + +async def list_all_resources() -> None: + """Fetch all resources using pagination.""" + async with stdio_client(StdioServerParameters(command="uv", args=["run", "mcp-simple-pagination"])) as ( + read, + write, + ): + async with ClientSession(read, write) as session: + await session.initialize() + + all_resources: list[Resource] = [] + cursor = None + + while True: + # Fetch a page of resources + result = await session.list_resources(cursor=cursor) + all_resources.extend(result.resources) + + print(f"Fetched {len(result.resources)} resources") + + # Check if there are more pages + if result.nextCursor: + cursor = result.nextCursor + else: + break + + print(f"Total resources: {len(all_resources)}") + + +if __name__ == "__main__": + asyncio.run(list_all_resources()) diff --git a/examples/snippets/servers/pagination_example.py b/examples/snippets/servers/pagination_example.py new file mode 100644 index 000000000..3852a209c --- /dev/null +++ b/examples/snippets/servers/pagination_example.py @@ -0,0 +1,35 @@ +""" +Example of implementing pagination with MCP server decorators. +""" + +from pydantic import AnyUrl + +import mcp.types as types +from mcp.server.lowlevel import Server + +# Initialize the server +server = Server("paginated-server") + +# Sample data to paginate +ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items + + +@server.list_resources_paginated() +async def list_resources_paginated(cursor: types.Cursor | None) -> types.ListResourcesResult: + """List resources with pagination support.""" + page_size = 10 + + # Parse cursor to get offset + start = 0 if cursor is None else int(cursor) + end = start + page_size + + # Get page of resources + page_items = [ + types.Resource(uri=AnyUrl(f"resource://items/{item}"), name=item, description=f"Description for {item}") + for item in ITEMS[start:end] + ] + + # Determine next cursor + next_cursor = str(end) if end < len(ITEMS) else None + + return types.ListResourcesResult(resources=page_items, nextCursor=next_cursor) diff --git a/uv.lock b/uv.lock index 7979f9aab..68abdcc4f 100644 --- a/uv.lock +++ b/uv.lock @@ -1,11 +1,12 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" [manifest] members = [ "mcp", "mcp-simple-auth", + "mcp-simple-pagination", "mcp-simple-prompt", "mcp-simple-resource", "mcp-simple-streamablehttp", @@ -730,6 +731,39 @@ dev = [ { name = "ruff", specifier = ">=0.8.5" }, ] +[[package]] +name = "mcp-simple-pagination" +version = "0.1.0" +source = { editable = "examples/servers/simple-pagination" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.2.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-prompt" version = "0.1.0" From 203b3ad879cb369c624f11fcb2f74a6b2e46a64c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 22 Aug 2025 17:24:57 +0100 Subject: [PATCH 04/11] switch pagination to single decorator with callback inspection --- README.md | 10 +- .../mcp_simple_pagination/server.py | 6 +- .../snippets/servers/pagination_example.py | 2 +- src/mcp/server/lowlevel/func_inspection.py | 54 ++++++ src/mcp/server/lowlevel/server.py | 116 +++++++------ tests/server/lowlevel/test_func_inspection.py | 141 +++++++++++++++ tests/server/lowlevel/test_server_listing.py | 162 ++++++++++++++++++ .../server/lowlevel/test_server_pagination.py | 6 +- 8 files changed, 429 insertions(+), 68 deletions(-) create mode 100644 src/mcp/server/lowlevel/func_inspection.py create mode 100644 tests/server/lowlevel/test_func_inspection.py create mode 100644 tests/server/lowlevel/test_server_listing.py diff --git a/README.md b/README.md index fda988898..95c871a4e 100644 --- a/README.md +++ b/README.md @@ -1762,7 +1762,7 @@ server = Server("paginated-server") ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items -@server.list_resources_paginated() +@server.list_resources() async def list_resources_paginated(cursor: types.Cursor | None) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 @@ -1786,12 +1786,6 @@ async def list_resources_paginated(cursor: types.Cursor | None) -> types.ListRes _Full example: [examples/snippets/servers/pagination_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/pagination_example.py)_ -Similar decorators are available for all list operations: - -- `@server.list_tools_paginated()` - for paginating tools -- `@server.list_resources_paginated()` - for paginating resources -- `@server.list_prompts_paginated()` - for paginating prompts - #### Client-side Consumption @@ -1849,8 +1843,6 @@ _Full example: [examples/snippets/clients/pagination_client.py](https://github.c - **Backward compatible** - clients that don't support pagination will still work (they'll just get the first page) - **Flexible page sizes** - Each endpoint can define its own page size based on data characteristics -> **NOTE**: The paginated decorators (`list_tools_paginated()`, `list_resources_paginated()`, `list_prompts_paginated()`) are mutually exclusive with their non-paginated counterparts and cannot be used together on the same server instance. - See the [simple-pagination example](examples/servers/simple-pagination) for a complete implementation. ### Writing MCP Clients diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/server.py b/examples/servers/simple-pagination/mcp_simple_pagination/server.py index 8a7ebdb9c..97f545718 100644 --- a/examples/servers/simple-pagination/mcp_simple_pagination/server.py +++ b/examples/servers/simple-pagination/mcp_simple_pagination/server.py @@ -58,7 +58,7 @@ def main(port: int, transport: str) -> int: app = Server("mcp-simple-pagination") # Paginated list_tools - returns 5 tools per page - @app.list_tools_paginated() + @app.list_tools() async def list_tools_paginated(cursor: types.Cursor | None) -> types.ListToolsResult: page_size = 5 @@ -84,7 +84,7 @@ async def list_tools_paginated(cursor: types.Cursor | None) -> types.ListToolsRe return types.ListToolsResult(tools=page_tools, nextCursor=next_cursor) # Paginated list_resources - returns 10 resources per page - @app.list_resources_paginated() + @app.list_resources() async def list_resources_paginated( cursor: types.Cursor | None, ) -> types.ListResourcesResult: @@ -112,7 +112,7 @@ async def list_resources_paginated( return types.ListResourcesResult(resources=page_resources, nextCursor=next_cursor) # Paginated list_prompts - returns 7 prompts per page - @app.list_prompts_paginated() + @app.list_prompts() async def list_prompts_paginated( cursor: types.Cursor | None, ) -> types.ListPromptsResult: diff --git a/examples/snippets/servers/pagination_example.py b/examples/snippets/servers/pagination_example.py index 3852a209c..c8c99323c 100644 --- a/examples/snippets/servers/pagination_example.py +++ b/examples/snippets/servers/pagination_example.py @@ -14,7 +14,7 @@ ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items -@server.list_resources_paginated() +@server.list_resources() async def list_resources_paginated(cursor: types.Cursor | None) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py new file mode 100644 index 000000000..9573ee0ae --- /dev/null +++ b/src/mcp/server/lowlevel/func_inspection.py @@ -0,0 +1,54 @@ +import inspect +from collections.abc import Callable +from typing import Any + + +def accepts_cursor(func: Callable[..., Any]) -> bool: + """ + True if the function accepts a cursor parameter call, otherwise false. + + `accepts_cursor` does not validate that the function will work. For + example, if `func` contains keyword-only arguments with no defaults, + then it will not work when used in the `lowlevel/server.py` code, but + this function will not raise an exception. + """ + try: + sig = inspect.signature(func) + except (ValueError, TypeError): + return False + + params = dict(sig.parameters.items()) + + method = inspect.ismethod(func) + + if method: + params.pop("self", None) + params.pop("cls", None) + + if len(params) == 0: + # No parameters at all - can't accept cursor + return False + + # Check if ALL remaining parameters are keyword-only + all_keyword_only = all(param.kind == inspect.Parameter.KEYWORD_ONLY for param in params.values()) + + if all_keyword_only: + # If all params are keyword-only, check if they ALL have defaults + # If they do, the function can be called with no arguments -> no cursor + all_have_defaults = all(param.default is not inspect.Parameter.empty for param in params.values()) + return not all_have_defaults # False if all have defaults (no cursor), True otherwise + + # Check if the ONLY parameter is **kwargs (VAR_KEYWORD) + # A function with only **kwargs can't accept a positional cursor argument + if len(params) == 1: + only_param = next(iter(params.values())) + if only_param.kind == inspect.Parameter.VAR_KEYWORD: + return False # Can't pass positional cursor to **kwargs + + # Has at least one positional or variadic parameter - can accept cursor + # Important note: this is designed to _not_ handle the situation where + # there are multiple keyword only arguments with no defaults. In those + # situations it's an invalid handler function, and will error. But it's + # not the responsibility of this function to check the validity of a + # callback. + return True diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 52d116612..fe36f28f1 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -82,6 +82,7 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types +from mcp.server.lowlevel.func_inspection import accepts_cursor from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -229,25 +230,29 @@ def request_context( return request_ctx.get() def list_prompts(self): - def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]): + def decorator( + func: Callable[[], Awaitable[list[types.Prompt]]] + | Callable[[types.Cursor | None], Awaitable[types.ListPromptsResult]], + ): logger.debug("Registering handler for PromptListRequest") + pass_cursor = accepts_cursor(func) - async def handler(_: Any): - prompts = await func() - return types.ServerResult(types.ListPromptsResult(prompts=prompts)) + if pass_cursor: + cursor_func = cast(Callable[[types.Cursor | None], Awaitable[types.ListPromptsResult]], func) - self.request_handlers[types.ListPromptsRequest] = handler - return func + async def cursor_handler(req: types.ListPromptsRequest): + result = await cursor_func(req.params.cursor if req.params is not None else None) + return types.ServerResult(result) - return decorator + handler = cursor_handler + else: + list_func = cast(Callable[[], Awaitable[list[types.Prompt]]], func) - def list_prompts_paginated(self): - def decorator(func: Callable[[types.Cursor | None], Awaitable[types.ListPromptsResult]]): - logger.debug("Registering handler for PromptListRequest with pagination") + async def list_handler(_: types.ListPromptsRequest): + result = await list_func() + return types.ServerResult(types.ListPromptsResult(prompts=result)) - async def handler(req: types.ListPromptsRequest): - result = await func(req.params.cursor if req.params else None) - return types.ServerResult(result) + handler = list_handler self.request_handlers[types.ListPromptsRequest] = handler return func @@ -270,25 +275,29 @@ async def handler(req: types.GetPromptRequest): return decorator def list_resources(self): - def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): + def decorator( + func: Callable[[], Awaitable[list[types.Resource]]] + | Callable[[types.Cursor | None], Awaitable[types.ListResourcesResult]], + ): logger.debug("Registering handler for ListResourcesRequest") + pass_cursor = accepts_cursor(func) - async def handler(_: Any): - resources = await func() - return types.ServerResult(types.ListResourcesResult(resources=resources)) + if pass_cursor: + cursor_func = cast(Callable[[types.Cursor | None], Awaitable[types.ListResourcesResult]], func) - self.request_handlers[types.ListResourcesRequest] = handler - return func + async def cursor_handler(req: types.ListResourcesRequest): + result = await cursor_func(req.params.cursor if req.params is not None else None) + return types.ServerResult(result) - return decorator + handler = cursor_handler + else: + list_func = cast(Callable[[], Awaitable[list[types.Resource]]], func) - def list_resources_paginated(self): - def decorator(func: Callable[[types.Cursor | None], Awaitable[types.ListResourcesResult]]): - logger.debug("Registering handler for ListResourcesRequest with pagination") + async def list_handler(_: types.ListResourcesRequest): + result = await list_func() + return types.ServerResult(types.ListResourcesResult(resources=result)) - async def handler(req: types.ListResourcesRequest): - result = await func(req.params.cursor if req.params else None) - return types.ServerResult(result) + handler = list_handler self.request_handlers[types.ListResourcesRequest] = handler return func @@ -406,33 +415,36 @@ async def handler(req: types.UnsubscribeRequest): return decorator def list_tools(self): - def decorator(func: Callable[[], Awaitable[list[types.Tool]]]): + def decorator( + func: Callable[[], Awaitable[list[types.Tool]]] + | Callable[[types.Cursor | None], Awaitable[types.ListToolsResult]], + ): logger.debug("Registering handler for ListToolsRequest") - - async def handler(_: Any): - tools = await func() - # Refresh the tool cache - self._tool_cache.clear() - for tool in tools: - self._tool_cache[tool.name] = tool - return types.ServerResult(types.ListToolsResult(tools=tools)) - - self.request_handlers[types.ListToolsRequest] = handler - return func - - return decorator - - def list_tools_paginated(self): - def decorator(func: Callable[[types.Cursor | None], Awaitable[types.ListToolsResult]]): - logger.debug("Registering paginated handler for ListToolsRequest") - - async def handler(request: types.ListToolsRequest): - cursor = request.params.cursor if request.params else None - result = await func(cursor) - # Refresh the tool cache with returned tools - for tool in result.tools: - self._tool_cache[tool.name] = tool - return types.ServerResult(result) + pass_cursor = accepts_cursor(func) + + if pass_cursor: + cursor_func = cast(Callable[[types.Cursor | None], Awaitable[types.ListToolsResult]], func) + + async def cursor_handler(req: types.ListToolsRequest): + result = await cursor_func(req.params.cursor if req.params is not None else None) + # Refresh the tool cache with returned tools + for tool in result.tools: + self._tool_cache[tool.name] = tool + return types.ServerResult(result) + + handler = cursor_handler + else: + list_func = cast(Callable[[], Awaitable[list[types.Tool]]], func) + + async def list_handler(req: types.ListToolsRequest): + result = await list_func() + # Clear and refresh the entire tool cache + self._tool_cache.clear() + for tool in result: + self._tool_cache[tool.name] = tool + return types.ServerResult(types.ListToolsResult(tools=result)) + + handler = list_handler self.request_handlers[types.ListToolsRequest] = handler return func diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py new file mode 100644 index 000000000..4114b329d --- /dev/null +++ b/tests/server/lowlevel/test_func_inspection.py @@ -0,0 +1,141 @@ +from collections.abc import Callable +from typing import Any + +import pytest + +from mcp import types +from mcp.server.lowlevel.func_inspection import accepts_cursor + + +# Test fixtures - functions and methods with various signatures +class MyClass: + async def no_cursor_method(self): + """Instance method without cursor parameter""" + pass + + async def cursor_method(self, cursor: types.Cursor | None): + """Instance method with cursor parameter""" + pass + + @classmethod + async def no_cursor_class_method(cls): + """Class method without cursor parameter""" + pass + + @classmethod + async def cursor_class_method(cls, cursor: types.Cursor | None): + """Class method with cursor parameter""" + pass + + @staticmethod + async def no_cursor_static_method(): + """Static method without cursor parameter""" + pass + + @staticmethod + async def cursor_static_method(cursor: types.Cursor | None): + """Static method with cursor parameter""" + pass + + +async def no_cursor_func(): + """Function without cursor parameter""" + pass + + +async def cursor_func(cursor: types.Cursor | None): + """Function with cursor parameter""" + pass + + +async def cursor_func_different_name(c: types.Cursor | None): + """Function with cursor parameter but different arg name""" + pass + + +async def cursor_func_with_self(self: types.Cursor | None): + """Function with parameter named 'self' (edge case)""" + pass + + +async def var_positional_func(*args: Any): + """Function with *args""" + pass + + +async def positional_with_var_positional_func(cursor: types.Cursor | None, *args: Any): + """Function with cursor and *args""" + pass + + +async def var_keyword_func(**kwargs: Any): + """Function with **kwargs""" + pass + + +async def cursor_with_var_keyword_func(cursor: types.Cursor | None, **kwargs: Any): + """Function with cursor and **kwargs""" + pass + + +async def cursor_with_default(cursor: types.Cursor | None = None): + """Function with cursor parameter having default value""" + pass + + +async def keyword_only_with_defaults(*, cursor: types.Cursor | None = None): + """Function with keyword-only cursor with default""" + pass + + +async def keyword_only_multiple_all_defaults(*, a: str = "test", b: int = 42): + """Function with multiple keyword-only params all with defaults""" + pass + + +async def mixed_positional_and_keyword(cursor: types.Cursor | None, *, extra: str = "test"): + """Function with positional and keyword-only params""" + pass + + +@pytest.mark.parametrize( + "callable_obj,expected,description", + [ + # Regular functions + (no_cursor_func, False, "function without parameters"), + (cursor_func, True, "function with cursor parameter"), + (cursor_func_different_name, True, "function with cursor (different param name)"), + (cursor_func_with_self, True, "function with param named 'self'"), + # Instance methods + (MyClass().no_cursor_method, False, "instance method without cursor"), + (MyClass().cursor_method, True, "instance method with cursor"), + # Class methods + (MyClass.no_cursor_class_method, False, "class method without cursor"), + (MyClass.cursor_class_method, True, "class method with cursor"), + # Static methods + (MyClass.no_cursor_static_method, False, "static method without cursor"), + (MyClass.cursor_static_method, True, "static method with cursor"), + # Variadic parameters + (var_positional_func, True, "function with *args"), + (positional_with_var_positional_func, True, "function with cursor and *args"), + (var_keyword_func, False, "function with **kwargs"), + (cursor_with_var_keyword_func, True, "function with cursor and **kwargs"), + # Edge cases + (cursor_with_default, True, "function with cursor having default value"), + # Keyword-only parameters + (keyword_only_with_defaults, False, "keyword-only with default (can call with no args)"), + (keyword_only_multiple_all_defaults, False, "multiple keyword-only all with defaults"), + (mixed_positional_and_keyword, True, "mixed positional and keyword-only params"), + ], + ids=lambda x: x if isinstance(x, str) else "", +) +def test_accepts_cursor(callable_obj: Callable[..., Any], expected: bool, description: str): + """Test that accepts_cursor correctly identifies functions that accept a cursor parameter. + + The function should return True if the callable can potentially accept a positional + cursor argument. Returns False if: + - No parameters at all + - Only keyword-only parameters that ALL have defaults (can call with no args) + - Only **kwargs parameter (can't accept positional arguments) + """ + assert accepts_cursor(callable_obj) == expected, f"Failed for {description}" diff --git a/tests/server/lowlevel/test_server_listing.py b/tests/server/lowlevel/test_server_listing.py new file mode 100644 index 000000000..9474edb3f --- /dev/null +++ b/tests/server/lowlevel/test_server_listing.py @@ -0,0 +1,162 @@ +"""Basic tests for list_prompts, list_resources, and list_tools decorators without pagination.""" + +import pytest +from pydantic import AnyUrl + +from mcp.server import Server +from mcp.types import ( + ListPromptsRequest, + ListPromptsResult, + ListResourcesRequest, + ListResourcesResult, + ListToolsRequest, + ListToolsResult, + Prompt, + Resource, + ServerResult, + Tool, +) + + +@pytest.mark.anyio +async def test_list_prompts_basic() -> None: + """Test basic prompt listing without pagination.""" + server = Server("test") + + test_prompts = [ + Prompt(name="prompt1", description="First prompt"), + Prompt(name="prompt2", description="Second prompt"), + ] + + @server.list_prompts() + async def handle_list_prompts() -> list[Prompt]: + return test_prompts + + handler = server.request_handlers[ListPromptsRequest] + request = ListPromptsRequest(method="prompts/list", params=None) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListPromptsResult) + assert result.root.prompts == test_prompts + + +@pytest.mark.anyio +async def test_list_resources_basic() -> None: + """Test basic resource listing without pagination.""" + server = Server("test") + + test_resources = [ + Resource(uri=AnyUrl("file:///test1.txt"), name="Test 1"), + Resource(uri=AnyUrl("file:///test2.txt"), name="Test 2"), + ] + + @server.list_resources() + async def handle_list_resources() -> list[Resource]: + return test_resources + + handler = server.request_handlers[ListResourcesRequest] + request = ListResourcesRequest(method="resources/list", params=None) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListResourcesResult) + assert result.root.resources == test_resources + + +@pytest.mark.anyio +async def test_list_tools_basic() -> None: + """Test basic tool listing without pagination.""" + server = Server("test") + + test_tools = [ + Tool( + name="tool1", + description="First tool", + inputSchema={ + "type": "object", + "properties": { + "message": {"type": "string"}, + }, + "required": ["message"], + }, + ), + Tool( + name="tool2", + description="Second tool", + inputSchema={ + "type": "object", + "properties": { + "count": {"type": "number"}, + "enabled": {"type": "boolean"}, + }, + "required": ["count"], + }, + ), + ] + + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return test_tools + + handler = server.request_handlers[ListToolsRequest] + request = ListToolsRequest(method="tools/list", params=None) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListToolsResult) + assert result.root.tools == test_tools + + +@pytest.mark.anyio +async def test_list_prompts_empty() -> None: + """Test listing with empty results.""" + server = Server("test") + + @server.list_prompts() + async def handle_list_prompts() -> list[Prompt]: + return [] + + handler = server.request_handlers[ListPromptsRequest] + request = ListPromptsRequest(method="prompts/list", params=None) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListPromptsResult) + assert result.root.prompts == [] + + +@pytest.mark.anyio +async def test_list_resources_empty() -> None: + """Test listing with empty results.""" + server = Server("test") + + @server.list_resources() + async def handle_list_resources() -> list[Resource]: + return [] + + handler = server.request_handlers[ListResourcesRequest] + request = ListResourcesRequest(method="resources/list", params=None) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListResourcesResult) + assert result.root.resources == [] + + +@pytest.mark.anyio +async def test_list_tools_empty() -> None: + """Test listing with empty results.""" + server = Server("test") + + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return [] + + handler = server.request_handlers[ListToolsRequest] + request = ListToolsRequest(method="tools/list", params=None) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListToolsResult) + assert result.root.tools == [] diff --git a/tests/server/lowlevel/test_server_pagination.py b/tests/server/lowlevel/test_server_pagination.py index f2c786e45..3a6b9f8ba 100644 --- a/tests/server/lowlevel/test_server_pagination.py +++ b/tests/server/lowlevel/test_server_pagination.py @@ -22,7 +22,7 @@ async def test_list_prompts_pagination() -> None: # Track what cursor was received received_cursor: Cursor | None = None - @server.list_prompts_paginated() + @server.list_prompts() async def handle_list_prompts(cursor: Cursor | None) -> ListPromptsResult: nonlocal received_cursor received_cursor = cursor @@ -51,7 +51,7 @@ async def test_list_resources_pagination() -> None: # Track what cursor was received received_cursor: Cursor | None = None - @server.list_resources_paginated() + @server.list_resources() async def handle_list_resources(cursor: Cursor | None) -> ListResourcesResult: nonlocal received_cursor received_cursor = cursor @@ -82,7 +82,7 @@ async def test_list_tools_pagination() -> None: # Track what cursor was received received_cursor: Cursor | None = None - @server.list_tools_paginated() + @server.list_tools() async def handle_list_tools(cursor: Cursor | None) -> ListToolsResult: nonlocal received_cursor received_cursor = cursor From a41f972c94d2155dad548cde677aa7ee8a2d622f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 1 Sep 2025 13:50:54 +0100 Subject: [PATCH 05/11] chore: clean up inspection code to remove redundant param inspection --- src/mcp/server/lowlevel/func_inspection.py | 6 ---- tests/server/lowlevel/test_func_inspection.py | 32 +++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py index 9573ee0ae..f69db4b95 100644 --- a/src/mcp/server/lowlevel/func_inspection.py +++ b/src/mcp/server/lowlevel/func_inspection.py @@ -19,12 +19,6 @@ def accepts_cursor(func: Callable[..., Any]) -> bool: params = dict(sig.parameters.items()) - method = inspect.ismethod(func) - - if method: - params.pop("self", None) - params.pop("cls", None) - if len(params) == 0: # No parameters at all - can't accept cursor return False diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py index 4114b329d..cb00d9d78 100644 --- a/tests/server/lowlevel/test_func_inspection.py +++ b/tests/server/lowlevel/test_func_inspection.py @@ -13,20 +13,42 @@ async def no_cursor_method(self): """Instance method without cursor parameter""" pass + # noinspection PyMethodParameters + async def no_cursor_method_bad_self_name(bad): # pyright: ignore[reportSelfClsParameterName] + """Instance method with cursor parameter, but with bad self name""" + pass + async def cursor_method(self, cursor: types.Cursor | None): """Instance method with cursor parameter""" pass + # noinspection PyMethodParameters + async def cursor_method_bad_self_name(bad, cursor: types.Cursor | None): # pyright: ignore[reportSelfClsParameterName] + """Instance method with cursor parameter, but with bad self name""" + pass + @classmethod async def no_cursor_class_method(cls): """Class method without cursor parameter""" pass + # noinspection PyMethodParameters + @classmethod + async def no_cursor_class_method_bad_cls_name(bad): # pyright: ignore[reportSelfClsParameterName] + """Class method without cursor parameter, but with bad cls name""" + pass + @classmethod async def cursor_class_method(cls, cursor: types.Cursor | None): """Class method with cursor parameter""" pass + # noinspection PyMethodParameters + @classmethod + async def cursor_class_method_bad_cls_name(bad, cursor: types.Cursor | None): # pyright: ignore[reportSelfClsParameterName] + """Class method with cursor parameter, but with bad cls name""" + pass + @staticmethod async def no_cursor_static_method(): """Static method without cursor parameter""" @@ -37,6 +59,11 @@ async def cursor_static_method(cursor: types.Cursor | None): """Static method with cursor parameter""" pass + @staticmethod + async def cursor_static_method_bad_arg_name(self: types.Cursor | None): # pyright: ignore[reportSelfClsParameterName] + """Static method with cursor parameter, but the cursor argument is named self""" + pass + async def no_cursor_func(): """Function without cursor parameter""" @@ -108,13 +135,18 @@ async def mixed_positional_and_keyword(cursor: types.Cursor | None, *, extra: st (cursor_func_with_self, True, "function with param named 'self'"), # Instance methods (MyClass().no_cursor_method, False, "instance method without cursor"), + (MyClass().no_cursor_method_bad_self_name, False, "instance method without cursor (bad self name)"), (MyClass().cursor_method, True, "instance method with cursor"), + (MyClass().cursor_method_bad_self_name, True, "instance method with cursor (bad self name)"), # Class methods (MyClass.no_cursor_class_method, False, "class method without cursor"), + (MyClass.no_cursor_class_method_bad_cls_name, False, "class method without cursor (bad cls name)"), (MyClass.cursor_class_method, True, "class method with cursor"), + (MyClass.cursor_class_method_bad_cls_name, True, "class method with cursor (bad cls name)"), # Static methods (MyClass.no_cursor_static_method, False, "static method without cursor"), (MyClass.cursor_static_method, True, "static method with cursor"), + (MyClass.cursor_static_method_bad_arg_name, True, "static method with cursor (bad arg name)"), # Variadic parameters (var_positional_func, True, "function with *args"), (positional_with_var_positional_func, True, "function with cursor and *args"), From dd07224e1fdb76188370f268a90ecb0d47b17155 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 1 Sep 2025 14:58:22 +0100 Subject: [PATCH 06/11] feat: change to passing requests instead of cursors for pagination --- README.md | 5 +- .../mcp_simple_pagination/server.py | 9 +- .../snippets/servers/pagination_example.py | 5 +- src/mcp/server/lowlevel/func_inspection.py | 2 +- src/mcp/server/lowlevel/server.py | 44 +++---- tests/server/lowlevel/test_func_inspection.py | 122 +++++++++--------- .../server/lowlevel/test_server_pagination.py | 64 +++++---- 7 files changed, 134 insertions(+), 117 deletions(-) diff --git a/README.md b/README.md index 95c871a4e..e2ef5a7ca 100644 --- a/README.md +++ b/README.md @@ -1763,10 +1763,13 @@ ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items @server.list_resources() -async def list_resources_paginated(cursor: types.Cursor | None) -> types.ListResourcesResult: +async def list_resources_paginated(request: types.ListResourcesRequest) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 + # Extract cursor from request params + cursor = request.params.cursor if request.params is not None else None + # Parse cursor to get offset start = 0 if cursor is None else int(cursor) end = start + page_size diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/server.py b/examples/servers/simple-pagination/mcp_simple_pagination/server.py index 97f545718..360cbc3cf 100644 --- a/examples/servers/simple-pagination/mcp_simple_pagination/server.py +++ b/examples/servers/simple-pagination/mcp_simple_pagination/server.py @@ -59,9 +59,10 @@ def main(port: int, transport: str) -> int: # Paginated list_tools - returns 5 tools per page @app.list_tools() - async def list_tools_paginated(cursor: types.Cursor | None) -> types.ListToolsResult: + async def list_tools_paginated(request: types.ListToolsRequest) -> types.ListToolsResult: page_size = 5 + cursor = request.params.cursor if request.params is not None else None if cursor is None: # First page start_idx = 0 @@ -86,10 +87,11 @@ async def list_tools_paginated(cursor: types.Cursor | None) -> types.ListToolsRe # Paginated list_resources - returns 10 resources per page @app.list_resources() async def list_resources_paginated( - cursor: types.Cursor | None, + request: types.ListResourcesRequest, ) -> types.ListResourcesResult: page_size = 10 + cursor = request.params.cursor if request.params is not None else None if cursor is None: # First page start_idx = 0 @@ -114,10 +116,11 @@ async def list_resources_paginated( # Paginated list_prompts - returns 7 prompts per page @app.list_prompts() async def list_prompts_paginated( - cursor: types.Cursor | None, + request: types.ListPromptsRequest, ) -> types.ListPromptsResult: page_size = 7 + cursor = request.params.cursor if request.params is not None else None if cursor is None: # First page start_idx = 0 diff --git a/examples/snippets/servers/pagination_example.py b/examples/snippets/servers/pagination_example.py index c8c99323c..70c3b3492 100644 --- a/examples/snippets/servers/pagination_example.py +++ b/examples/snippets/servers/pagination_example.py @@ -15,10 +15,13 @@ @server.list_resources() -async def list_resources_paginated(cursor: types.Cursor | None) -> types.ListResourcesResult: +async def list_resources_paginated(request: types.ListResourcesRequest) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 + # Extract cursor from request params + cursor = request.params.cursor if request.params is not None else None + # Parse cursor to get offset start = 0 if cursor is None else int(cursor) end = start + page_size diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py index f69db4b95..7f6e18860 100644 --- a/src/mcp/server/lowlevel/func_inspection.py +++ b/src/mcp/server/lowlevel/func_inspection.py @@ -3,7 +3,7 @@ from typing import Any -def accepts_cursor(func: Callable[..., Any]) -> bool: +def accepts_request(func: Callable[..., Any]) -> bool: """ True if the function accepts a cursor parameter call, otherwise false. diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index fe36f28f1..4ce28401a 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -82,7 +82,7 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types -from mcp.server.lowlevel.func_inspection import accepts_cursor +from mcp.server.lowlevel.func_inspection import accepts_request from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -232,19 +232,19 @@ def request_context( def list_prompts(self): def decorator( func: Callable[[], Awaitable[list[types.Prompt]]] - | Callable[[types.Cursor | None], Awaitable[types.ListPromptsResult]], + | Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], ): logger.debug("Registering handler for PromptListRequest") - pass_cursor = accepts_cursor(func) + pass_request = accepts_request(func) - if pass_cursor: - cursor_func = cast(Callable[[types.Cursor | None], Awaitable[types.ListPromptsResult]], func) + if pass_request: + request_func = cast(Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], func) - async def cursor_handler(req: types.ListPromptsRequest): - result = await cursor_func(req.params.cursor if req.params is not None else None) + async def request_handler(req: types.ListPromptsRequest): + result = await request_func(req) return types.ServerResult(result) - handler = cursor_handler + handler = request_handler else: list_func = cast(Callable[[], Awaitable[list[types.Prompt]]], func) @@ -277,19 +277,19 @@ async def handler(req: types.GetPromptRequest): def list_resources(self): def decorator( func: Callable[[], Awaitable[list[types.Resource]]] - | Callable[[types.Cursor | None], Awaitable[types.ListResourcesResult]], + | Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], ): logger.debug("Registering handler for ListResourcesRequest") - pass_cursor = accepts_cursor(func) + pass_request = accepts_request(func) - if pass_cursor: - cursor_func = cast(Callable[[types.Cursor | None], Awaitable[types.ListResourcesResult]], func) + if pass_request: + request_func = cast(Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], func) - async def cursor_handler(req: types.ListResourcesRequest): - result = await cursor_func(req.params.cursor if req.params is not None else None) + async def request_handler(req: types.ListResourcesRequest): + result = await request_func(req) return types.ServerResult(result) - handler = cursor_handler + handler = request_handler else: list_func = cast(Callable[[], Awaitable[list[types.Resource]]], func) @@ -417,22 +417,22 @@ async def handler(req: types.UnsubscribeRequest): def list_tools(self): def decorator( func: Callable[[], Awaitable[list[types.Tool]]] - | Callable[[types.Cursor | None], Awaitable[types.ListToolsResult]], + | Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], ): logger.debug("Registering handler for ListToolsRequest") - pass_cursor = accepts_cursor(func) + pass_request = accepts_request(func) - if pass_cursor: - cursor_func = cast(Callable[[types.Cursor | None], Awaitable[types.ListToolsResult]], func) + if pass_request: + request_func = cast(Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], func) - async def cursor_handler(req: types.ListToolsRequest): - result = await cursor_func(req.params.cursor if req.params is not None else None) + async def request_handler(req: types.ListToolsRequest): + result = await request_func(req) # Refresh the tool cache with returned tools for tool in result.tools: self._tool_cache[tool.name] = tool return types.ServerResult(result) - handler = cursor_handler + handler = request_handler else: list_func = cast(Callable[[], Awaitable[list[types.Tool]]], func) diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py index cb00d9d78..674675d8d 100644 --- a/tests/server/lowlevel/test_func_inspection.py +++ b/tests/server/lowlevel/test_func_inspection.py @@ -4,83 +4,83 @@ import pytest from mcp import types -from mcp.server.lowlevel.func_inspection import accepts_cursor +from mcp.server.lowlevel.func_inspection import accepts_request # Test fixtures - functions and methods with various signatures class MyClass: - async def no_cursor_method(self): - """Instance method without cursor parameter""" + async def no_request_method(self): + """Instance method without request parameter""" pass # noinspection PyMethodParameters - async def no_cursor_method_bad_self_name(bad): # pyright: ignore[reportSelfClsParameterName] - """Instance method with cursor parameter, but with bad self name""" + async def no_request_method_bad_self_name(bad): # pyright: ignore[reportSelfClsParameterName] + """Instance method without request parameter, but with bad self name""" pass - async def cursor_method(self, cursor: types.Cursor | None): - """Instance method with cursor parameter""" + async def request_method(self, request: types.ListPromptsRequest): + """Instance method with request parameter""" pass # noinspection PyMethodParameters - async def cursor_method_bad_self_name(bad, cursor: types.Cursor | None): # pyright: ignore[reportSelfClsParameterName] - """Instance method with cursor parameter, but with bad self name""" + async def request_method_bad_self_name(bad, request: types.ListPromptsRequest): # pyright: ignore[reportSelfClsParameterName] + """Instance method with request parameter, but with bad self name""" pass @classmethod - async def no_cursor_class_method(cls): - """Class method without cursor parameter""" + async def no_request_class_method(cls): + """Class method without request parameter""" pass # noinspection PyMethodParameters @classmethod - async def no_cursor_class_method_bad_cls_name(bad): # pyright: ignore[reportSelfClsParameterName] - """Class method without cursor parameter, but with bad cls name""" + async def no_request_class_method_bad_cls_name(bad): # pyright: ignore[reportSelfClsParameterName] + """Class method without request parameter, but with bad cls name""" pass @classmethod - async def cursor_class_method(cls, cursor: types.Cursor | None): - """Class method with cursor parameter""" + async def request_class_method(cls, request: types.ListPromptsRequest): + """Class method with request parameter""" pass # noinspection PyMethodParameters @classmethod - async def cursor_class_method_bad_cls_name(bad, cursor: types.Cursor | None): # pyright: ignore[reportSelfClsParameterName] - """Class method with cursor parameter, but with bad cls name""" + async def request_class_method_bad_cls_name(bad, request: types.ListPromptsRequest): # pyright: ignore[reportSelfClsParameterName] + """Class method with request parameter, but with bad cls name""" pass @staticmethod - async def no_cursor_static_method(): - """Static method without cursor parameter""" + async def no_request_static_method(): + """Static method without request parameter""" pass @staticmethod - async def cursor_static_method(cursor: types.Cursor | None): - """Static method with cursor parameter""" + async def request_static_method(request: types.ListPromptsRequest): + """Static method with request parameter""" pass @staticmethod - async def cursor_static_method_bad_arg_name(self: types.Cursor | None): # pyright: ignore[reportSelfClsParameterName] - """Static method with cursor parameter, but the cursor argument is named self""" + async def request_static_method_bad_arg_name(self: types.ListPromptsRequest): # pyright: ignore[reportSelfClsParameterName] + """Static method with request parameter, but the request argument is named self""" pass -async def no_cursor_func(): - """Function without cursor parameter""" +async def no_request_func(): + """Function without request parameter""" pass -async def cursor_func(cursor: types.Cursor | None): - """Function with cursor parameter""" +async def request_func(request: types.ListPromptsRequest): + """Function with request parameter""" pass -async def cursor_func_different_name(c: types.Cursor | None): - """Function with cursor parameter but different arg name""" +async def request_func_different_name(req: types.ListPromptsRequest): + """Function with request parameter but different arg name""" pass -async def cursor_func_with_self(self: types.Cursor | None): +async def request_func_with_self(self: types.ListPromptsRequest): """Function with parameter named 'self' (edge case)""" pass @@ -90,8 +90,8 @@ async def var_positional_func(*args: Any): pass -async def positional_with_var_positional_func(cursor: types.Cursor | None, *args: Any): - """Function with cursor and *args""" +async def positional_with_var_positional_func(request: types.ListPromptsRequest, *args: Any): + """Function with request and *args""" pass @@ -100,18 +100,18 @@ async def var_keyword_func(**kwargs: Any): pass -async def cursor_with_var_keyword_func(cursor: types.Cursor | None, **kwargs: Any): - """Function with cursor and **kwargs""" +async def request_with_var_keyword_func(request: types.ListPromptsRequest, **kwargs: Any): + """Function with request and **kwargs""" pass -async def cursor_with_default(cursor: types.Cursor | None = None): - """Function with cursor parameter having default value""" +async def request_with_default(request: types.ListPromptsRequest | None = None): + """Function with request parameter having default value""" pass -async def keyword_only_with_defaults(*, cursor: types.Cursor | None = None): - """Function with keyword-only cursor with default""" +async def keyword_only_with_defaults(*, request: types.ListPromptsRequest | None = None): + """Function with keyword-only request with default""" pass @@ -120,7 +120,7 @@ async def keyword_only_multiple_all_defaults(*, a: str = "test", b: int = 42): pass -async def mixed_positional_and_keyword(cursor: types.Cursor | None, *, extra: str = "test"): +async def mixed_positional_and_keyword(request: types.ListPromptsRequest, *, extra: str = "test"): """Function with positional and keyword-only params""" pass @@ -129,31 +129,31 @@ async def mixed_positional_and_keyword(cursor: types.Cursor | None, *, extra: st "callable_obj,expected,description", [ # Regular functions - (no_cursor_func, False, "function without parameters"), - (cursor_func, True, "function with cursor parameter"), - (cursor_func_different_name, True, "function with cursor (different param name)"), - (cursor_func_with_self, True, "function with param named 'self'"), + (no_request_func, False, "function without parameters"), + (request_func, True, "function with request parameter"), + (request_func_different_name, True, "function with request (different param name)"), + (request_func_with_self, True, "function with param named 'self'"), # Instance methods - (MyClass().no_cursor_method, False, "instance method without cursor"), - (MyClass().no_cursor_method_bad_self_name, False, "instance method without cursor (bad self name)"), - (MyClass().cursor_method, True, "instance method with cursor"), - (MyClass().cursor_method_bad_self_name, True, "instance method with cursor (bad self name)"), + (MyClass().no_request_method, False, "instance method without request"), + (MyClass().no_request_method_bad_self_name, False, "instance method without request (bad self name)"), + (MyClass().request_method, True, "instance method with request"), + (MyClass().request_method_bad_self_name, True, "instance method with request (bad self name)"), # Class methods - (MyClass.no_cursor_class_method, False, "class method without cursor"), - (MyClass.no_cursor_class_method_bad_cls_name, False, "class method without cursor (bad cls name)"), - (MyClass.cursor_class_method, True, "class method with cursor"), - (MyClass.cursor_class_method_bad_cls_name, True, "class method with cursor (bad cls name)"), + (MyClass.no_request_class_method, False, "class method without request"), + (MyClass.no_request_class_method_bad_cls_name, False, "class method without request (bad cls name)"), + (MyClass.request_class_method, True, "class method with request"), + (MyClass.request_class_method_bad_cls_name, True, "class method with request (bad cls name)"), # Static methods - (MyClass.no_cursor_static_method, False, "static method without cursor"), - (MyClass.cursor_static_method, True, "static method with cursor"), - (MyClass.cursor_static_method_bad_arg_name, True, "static method with cursor (bad arg name)"), + (MyClass.no_request_static_method, False, "static method without request"), + (MyClass.request_static_method, True, "static method with request"), + (MyClass.request_static_method_bad_arg_name, True, "static method with request (bad arg name)"), # Variadic parameters (var_positional_func, True, "function with *args"), - (positional_with_var_positional_func, True, "function with cursor and *args"), + (positional_with_var_positional_func, True, "function with request and *args"), (var_keyword_func, False, "function with **kwargs"), - (cursor_with_var_keyword_func, True, "function with cursor and **kwargs"), + (request_with_var_keyword_func, True, "function with request and **kwargs"), # Edge cases - (cursor_with_default, True, "function with cursor having default value"), + (request_with_default, True, "function with request having default value"), # Keyword-only parameters (keyword_only_with_defaults, False, "keyword-only with default (can call with no args)"), (keyword_only_multiple_all_defaults, False, "multiple keyword-only all with defaults"), @@ -161,13 +161,13 @@ async def mixed_positional_and_keyword(cursor: types.Cursor | None, *, extra: st ], ids=lambda x: x if isinstance(x, str) else "", ) -def test_accepts_cursor(callable_obj: Callable[..., Any], expected: bool, description: str): - """Test that accepts_cursor correctly identifies functions that accept a cursor parameter. +def test_accepts_request(callable_obj: Callable[..., Any], expected: bool, description: str): + """Test that accepts_request correctly identifies functions that accept a request parameter. The function should return True if the callable can potentially accept a positional - cursor argument. Returns False if: + request argument. Returns False if: - No parameters at all - Only keyword-only parameters that ALL have defaults (can call with no args) - Only **kwargs parameter (can't accept positional arguments) """ - assert accepts_cursor(callable_obj) == expected, f"Failed for {description}" + assert accepts_request(callable_obj) == expected, f"Failed for {description}" diff --git a/tests/server/lowlevel/test_server_pagination.py b/tests/server/lowlevel/test_server_pagination.py index 3a6b9f8ba..8d64dd525 100644 --- a/tests/server/lowlevel/test_server_pagination.py +++ b/tests/server/lowlevel/test_server_pagination.py @@ -2,7 +2,6 @@ from mcp.server import Server from mcp.types import ( - Cursor, ListPromptsRequest, ListPromptsResult, ListResourcesRequest, @@ -19,27 +18,30 @@ async def test_list_prompts_pagination() -> None: server = Server("test") test_cursor = "test-cursor-123" - # Track what cursor was received - received_cursor: Cursor | None = None + # Track what request was received + received_request: ListPromptsRequest | None = None @server.list_prompts() - async def handle_list_prompts(cursor: Cursor | None) -> ListPromptsResult: - nonlocal received_cursor - received_cursor = cursor + async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: + nonlocal received_request + received_request = request return ListPromptsResult(prompts=[], nextCursor="next") handler = server.request_handlers[ListPromptsRequest] - # Test: No cursor provided -> handler receives None + # Test: No cursor provided -> handler receives request with None params request = ListPromptsRequest(method="prompts/list", params=None) result = await handler(request) - assert received_cursor is None + assert received_request is not None + assert received_request.params is None assert isinstance(result, ServerResult) - # Test: Cursor provided -> handler receives exact cursor value + # Test: Cursor provided -> handler receives request with cursor in params request_with_cursor = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor=test_cursor)) result2 = await handler(request_with_cursor) - assert received_cursor == test_cursor + assert received_request is not None + assert received_request.params is not None + assert received_request.params.cursor == test_cursor assert isinstance(result2, ServerResult) @@ -48,29 +50,32 @@ async def test_list_resources_pagination() -> None: server = Server("test") test_cursor = "resource-cursor-456" - # Track what cursor was received - received_cursor: Cursor | None = None + # Track what request was received + received_request: ListResourcesRequest | None = None @server.list_resources() - async def handle_list_resources(cursor: Cursor | None) -> ListResourcesResult: - nonlocal received_cursor - received_cursor = cursor + async def handle_list_resources(request: ListResourcesRequest) -> ListResourcesResult: + nonlocal received_request + received_request = request return ListResourcesResult(resources=[], nextCursor="next") handler = server.request_handlers[ListResourcesRequest] - # Test: No cursor provided -> handler receives None + # Test: No cursor provided -> handler receives request with None params request = ListResourcesRequest(method="resources/list", params=None) result = await handler(request) - assert received_cursor is None + assert received_request is not None + assert received_request.params is None assert isinstance(result, ServerResult) - # Test: Cursor provided -> handler receives exact cursor value + # Test: Cursor provided -> handler receives request with cursor in params request_with_cursor = ListResourcesRequest( method="resources/list", params=PaginatedRequestParams(cursor=test_cursor) ) result2 = await handler(request_with_cursor) - assert received_cursor == test_cursor + assert received_request is not None + assert received_request.params is not None + assert received_request.params.cursor == test_cursor assert isinstance(result2, ServerResult) @@ -79,25 +84,28 @@ async def test_list_tools_pagination() -> None: server = Server("test") test_cursor = "tools-cursor-789" - # Track what cursor was received - received_cursor: Cursor | None = None + # Track what request was received + received_request: ListToolsRequest | None = None @server.list_tools() - async def handle_list_tools(cursor: Cursor | None) -> ListToolsResult: - nonlocal received_cursor - received_cursor = cursor + async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: + nonlocal received_request + received_request = request return ListToolsResult(tools=[], nextCursor="next") handler = server.request_handlers[ListToolsRequest] - # Test: No cursor provided -> handler receives None + # Test: No cursor provided -> handler receives request with None params request = ListToolsRequest(method="tools/list", params=None) result = await handler(request) - assert received_cursor is None + assert received_request is not None + assert received_request.params is None assert isinstance(result, ServerResult) - # Test: Cursor provided -> handler receives exact cursor value + # Test: Cursor provided -> handler receives request with cursor in params request_with_cursor = ListToolsRequest(method="tools/list", params=PaginatedRequestParams(cursor=test_cursor)) result2 = await handler(request_with_cursor) - assert received_cursor == test_cursor + assert received_request is not None + assert received_request.params is not None + assert received_request.params.cursor == test_cursor assert isinstance(result2, ServerResult) From a38351fe140dfa5e52a3ad9d0e344695360d2a8d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 8 Sep 2025 13:41:21 -0700 Subject: [PATCH 07/11] fix: ruff error on unit test --- tests/server/lowlevel/test_func_inspection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py index 674675d8d..8499a3e33 100644 --- a/tests/server/lowlevel/test_func_inspection.py +++ b/tests/server/lowlevel/test_func_inspection.py @@ -60,7 +60,7 @@ async def request_static_method(request: types.ListPromptsRequest): pass @staticmethod - async def request_static_method_bad_arg_name(self: types.ListPromptsRequest): # pyright: ignore[reportSelfClsParameterName] + async def request_static_method_bad_arg_name(self: types.ListPromptsRequest): # pyright: ignore[reportSelfClsParameterName] # noqa: PLW0211 """Static method with request parameter, but the request argument is named self""" pass From a42e0973fadddc23a0df0461f1b2535045774966 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 15 Sep 2025 15:07:56 +0100 Subject: [PATCH 08/11] chore: rename and clarify function inspection code --- src/mcp/server/lowlevel/func_inspection.py | 25 ++++++++++--------- src/mcp/server/lowlevel/server.py | 8 +++--- tests/server/lowlevel/test_func_inspection.py | 13 ++++++---- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py index 7f6e18860..3a6a40f71 100644 --- a/src/mcp/server/lowlevel/func_inspection.py +++ b/src/mcp/server/lowlevel/func_inspection.py @@ -3,14 +3,13 @@ from typing import Any -def accepts_request(func: Callable[..., Any]) -> bool: +def accepts_single_positional_arg(func: Callable[..., Any]) -> bool: """ - True if the function accepts a cursor parameter call, otherwise false. + True if the function accepts at least one positional argument, otherwise false. - `accepts_cursor` does not validate that the function will work. For - example, if `func` contains keyword-only arguments with no defaults, - then it will not work when used in the `lowlevel/server.py` code, but - this function will not raise an exception. + This function intentionally does not define behavior for `func`s that + contain more than one positional argument, or any required keyword + arguments without defaults. """ try: sig = inspect.signature(func) @@ -20,7 +19,7 @@ def accepts_request(func: Callable[..., Any]) -> bool: params = dict(sig.parameters.items()) if len(params) == 0: - # No parameters at all - can't accept cursor + # No parameters at all - can't accept single argument return False # Check if ALL remaining parameters are keyword-only @@ -28,18 +27,20 @@ def accepts_request(func: Callable[..., Any]) -> bool: if all_keyword_only: # If all params are keyword-only, check if they ALL have defaults - # If they do, the function can be called with no arguments -> no cursor + # If they do, the function can be called with no arguments -> no argument all_have_defaults = all(param.default is not inspect.Parameter.empty for param in params.values()) - return not all_have_defaults # False if all have defaults (no cursor), True otherwise + if all_have_defaults: + return False + # otherwise, undefined (doesn't accept a positional argument, and requires at least one keyword only) # Check if the ONLY parameter is **kwargs (VAR_KEYWORD) - # A function with only **kwargs can't accept a positional cursor argument + # A function with only **kwargs can't accept a positional argument if len(params) == 1: only_param = next(iter(params.values())) if only_param.kind == inspect.Parameter.VAR_KEYWORD: - return False # Can't pass positional cursor to **kwargs + return False # Can't pass positional argument to **kwargs - # Has at least one positional or variadic parameter - can accept cursor + # Has at least one positional or variadic parameter - can accept argument # Important note: this is designed to _not_ handle the situation where # there are multiple keyword only arguments with no defaults. In those # situations it's an invalid handler function, and will error. But it's diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 4ce28401a..2d73aac35 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -82,7 +82,7 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types -from mcp.server.lowlevel.func_inspection import accepts_request +from mcp.server.lowlevel.func_inspection import accepts_single_positional_arg from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -235,7 +235,7 @@ def decorator( | Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], ): logger.debug("Registering handler for PromptListRequest") - pass_request = accepts_request(func) + pass_request = accepts_single_positional_arg(func) if pass_request: request_func = cast(Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], func) @@ -280,7 +280,7 @@ def decorator( | Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], ): logger.debug("Registering handler for ListResourcesRequest") - pass_request = accepts_request(func) + pass_request = accepts_single_positional_arg(func) if pass_request: request_func = cast(Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], func) @@ -420,7 +420,7 @@ def decorator( | Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], ): logger.debug("Registering handler for ListToolsRequest") - pass_request = accepts_request(func) + pass_request = accepts_single_positional_arg(func) if pass_request: request_func = cast(Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], func) diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py index 8499a3e33..d05538e52 100644 --- a/tests/server/lowlevel/test_func_inspection.py +++ b/tests/server/lowlevel/test_func_inspection.py @@ -4,10 +4,9 @@ import pytest from mcp import types -from mcp.server.lowlevel.func_inspection import accepts_request +from mcp.server.lowlevel.func_inspection import accepts_single_positional_arg -# Test fixtures - functions and methods with various signatures class MyClass: async def no_request_method(self): """Instance method without request parameter""" @@ -161,8 +160,12 @@ async def mixed_positional_and_keyword(request: types.ListPromptsRequest, *, ext ], ids=lambda x: x if isinstance(x, str) else "", ) -def test_accepts_request(callable_obj: Callable[..., Any], expected: bool, description: str): - """Test that accepts_request correctly identifies functions that accept a request parameter. +def test_accepts_single_positional_arg(callable_obj: Callable[..., Any], expected: bool, description: str): + """Test that `accepts_single_positional_arg` correctly identifies functions that accept a single argument. + + `accepts_single_positional_arg` is currently only used in the case of + the lowlevel server code checking whether a handler accepts a request + argument, so the test cases reference a "request" param/arg. The function should return True if the callable can potentially accept a positional request argument. Returns False if: @@ -170,4 +173,4 @@ def test_accepts_request(callable_obj: Callable[..., Any], expected: bool, descr - Only keyword-only parameters that ALL have defaults (can call with no args) - Only **kwargs parameter (can't accept positional arguments) """ - assert accepts_request(callable_obj) == expected, f"Failed for {description}" + assert accepts_single_positional_arg(callable_obj) == expected, f"Failed for {description}" From 2a592d24432fd648e0e7df7acdd57aa2973ef3e4 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 16 Sep 2025 19:15:39 +0100 Subject: [PATCH 09/11] feature: add type checking for passing request object --- src/mcp/server/lowlevel/func_inspection.py | 170 ++++++++++++++- src/mcp/server/lowlevel/server.py | 17 +- .../lowlevel/test_advanced_type_inspection.py | 193 +++++++++++++++++ .../lowlevel/test_deprecation_warnings.py | 194 ++++++++++++++++++ tests/server/lowlevel/test_server_listing.py | 56 +++-- .../lowlevel/test_type_accepts_request.py | 88 ++++++++ 6 files changed, 695 insertions(+), 23 deletions(-) create mode 100644 tests/server/lowlevel/test_advanced_type_inspection.py create mode 100644 tests/server/lowlevel/test_deprecation_warnings.py create mode 100644 tests/server/lowlevel/test_type_accepts_request.py diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py index 3a6a40f71..1e6d862fa 100644 --- a/src/mcp/server/lowlevel/func_inspection.py +++ b/src/mcp/server/lowlevel/func_inspection.py @@ -1,6 +1,8 @@ import inspect +import types +import warnings from collections.abc import Callable -from typing import Any +from typing import Any, TypeVar, Union, get_args, get_origin def accepts_single_positional_arg(func: Callable[..., Any]) -> bool: @@ -47,3 +49,169 @@ def accepts_single_positional_arg(func: Callable[..., Any]) -> bool: # not the responsibility of this function to check the validity of a # callback. return True + + +def get_first_parameter_type(func: Callable[..., Any]) -> Any: + """ + Get the type annotation of the first parameter of a function. + + Returns None if: + - The function has no parameters + - The first parameter has no type annotation + - The signature cannot be inspected + + Returns the actual annotation otherwise (could be a type, Any, Union, TypeVar, etc.) + """ + try: + sig = inspect.signature(func) + except (ValueError, TypeError): + return None + + params = list(sig.parameters.values()) + if not params: + return None + + first_param = params[0] + + # Skip *args and **kwargs + if first_param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + return None + + annotation = first_param.annotation + if annotation == inspect.Parameter.empty: + return None + + return annotation + + +def type_accepts_request(param_type: Any, request_type: type) -> bool: + """ + Check if a parameter type annotation can accept the request type. + + Handles: + - Exact type match + - Union types (checks if request_type is in the Union) + - TypeVars (checks if request_type matches the bound or constraints) + - Generic types (basic support) + - Any (always returns True) + + Returns False for None or incompatible types. + """ + if param_type is None: + return False + + # Check for Any type + if param_type is Any: + return True + + # Exact match + if param_type == request_type: + return True + + # Handle Union types (both typing.Union and | syntax) + origin = get_origin(param_type) + if origin is Union or origin is types.UnionType: + args = get_args(param_type) + # Check if request_type is in the Union + for arg in args: + if arg == request_type: + return True + # Recursively check each union member + if type_accepts_request(arg, request_type): + return True + return False + + # Handle TypeVar + if isinstance(param_type, TypeVar): + # Check if request_type matches the bound + if param_type.__bound__ is not None: + if request_type == param_type.__bound__: + return True + # Check if request_type is a subclass of the bound + try: + if issubclass(request_type, param_type.__bound__): + return True + except TypeError: + pass + + # Check constraints + if param_type.__constraints__: + for constraint in param_type.__constraints__: + if request_type == constraint: + return True + try: + if issubclass(request_type, constraint): + return True + except TypeError: + pass + + return False + + # For other generic types, check if request_type matches the origin + if origin is not None: + # Get the base generic type (e.g., list from list[str]) + return request_type == origin + + return False + + +def should_pass_request(func: Callable[..., Any], request_type: type) -> tuple[bool, bool]: + """ + Determine if a request should be passed to the function based on parameter type inspection. + + Returns a tuple of (should_pass_request, should_deprecate): + - should_pass_request: True if the request should be passed to the function + - should_deprecate: True if a deprecation warning should be issued + + The decision logic: + 1. If the function has no parameters -> (False, True) - old style without params, deprecate + 2. If the function has parameters but can't accept positional args -> (False, False) + 3. If the first parameter type accepts the request type -> (True, False) - pass request, no deprecation + 4. If the first parameter is typed as Any -> (True, True) - pass request but deprecate (effectively untyped) + 5. If the first parameter is typed with something incompatible -> (False, True) - old style, deprecate + 6. If the first parameter is untyped but accepts positional args -> (True, True) - pass request, deprecate + """ + can_accept_arg = accepts_single_positional_arg(func) + + if not can_accept_arg: + # Check if it has no parameters at all (old style) + try: + sig = inspect.signature(func) + if len(sig.parameters) == 0: + # Old style handler with no parameters - don't pass request but deprecate + return False, True + except (ValueError, TypeError): + pass + # Can't accept positional arguments for other reasons + return False, False + + param_type = get_first_parameter_type(func) + + if param_type is None: + # Untyped parameter - this is the old style, pass request but deprecate + return True, True + + # Check if the parameter type can accept the request + if type_accepts_request(param_type, request_type): + # Check if it's Any - if so, we should deprecate + if param_type is Any: + return True, True + # Properly typed to accept the request - pass request, no deprecation + return True, False + + # Parameter is typed with something incompatible - this is an old style handler expecting + # a different signature, don't pass request, issue deprecation + return False, True + + +def issue_deprecation_warning(func: Callable[..., Any], request_type: type) -> None: + """ + Issue a deprecation warning for handlers that don't use the new request parameter style. + """ + func_name = getattr(func, "__name__", str(func)) + warnings.warn( + f"Handler '{func_name}' should accept a '{request_type.__name__}' parameter. " + "Support for handlers without this parameter will be removed in a future version.", + DeprecationWarning, + stacklevel=4, + ) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 2d73aac35..ac61849bc 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -82,7 +82,7 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types -from mcp.server.lowlevel.func_inspection import accepts_single_positional_arg +from mcp.server.lowlevel.func_inspection import issue_deprecation_warning, should_pass_request from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -235,7 +235,10 @@ def decorator( | Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], ): logger.debug("Registering handler for PromptListRequest") - pass_request = accepts_single_positional_arg(func) + pass_request, should_deprecate = should_pass_request(func, types.ListPromptsRequest) + + if should_deprecate: + issue_deprecation_warning(func, types.ListPromptsRequest) if pass_request: request_func = cast(Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], func) @@ -280,7 +283,10 @@ def decorator( | Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], ): logger.debug("Registering handler for ListResourcesRequest") - pass_request = accepts_single_positional_arg(func) + pass_request, should_deprecate = should_pass_request(func, types.ListResourcesRequest) + + if should_deprecate: + issue_deprecation_warning(func, types.ListResourcesRequest) if pass_request: request_func = cast(Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], func) @@ -420,7 +426,10 @@ def decorator( | Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], ): logger.debug("Registering handler for ListToolsRequest") - pass_request = accepts_single_positional_arg(func) + pass_request, should_deprecate = should_pass_request(func, types.ListToolsRequest) + + if should_deprecate: + issue_deprecation_warning(func, types.ListToolsRequest) if pass_request: request_func = cast(Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], func) diff --git a/tests/server/lowlevel/test_advanced_type_inspection.py b/tests/server/lowlevel/test_advanced_type_inspection.py new file mode 100644 index 000000000..0dcd27941 --- /dev/null +++ b/tests/server/lowlevel/test_advanced_type_inspection.py @@ -0,0 +1,193 @@ +"""Tests for advanced type inspection features in pagination handlers.""" + +import warnings +from typing import Any, TypeVar + +import pytest + +from mcp.server import Server +from mcp.types import ( + ListPromptsRequest, + ListPromptsResult, + ListToolsRequest, + Prompt, + ServerResult, +) + +# Define TypeVars for testing +T = TypeVar("T") +ConstrainedRequest = TypeVar("ConstrainedRequest", ListPromptsRequest, ListToolsRequest) + + +@pytest.mark.anyio +async def test_union_type_with_request_no_warning() -> None: + """Test that Union types containing the request type don't trigger warnings.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_prompts() + async def handle_list_prompts(request: ListPromptsRequest | None) -> ListPromptsResult: + assert request is not None + return ListPromptsResult(prompts=[]) + + # No deprecation warning should be issued for Union containing request type + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 0 + + +@pytest.mark.anyio +async def test_union_type_multiple_requests_no_warning() -> None: + """Test Union with multiple request types works correctly.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_prompts() + async def handle_list_prompts( + request: ListPromptsRequest | ListToolsRequest, + ) -> ListPromptsResult: + assert isinstance(request, ListPromptsRequest) + return ListPromptsResult(prompts=[]) + + # No deprecation warning - Union contains the request type + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 0 + + +@pytest.mark.anyio +async def test_any_type_triggers_warning() -> None: + """Test that Any type triggers deprecation warning.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_prompts() + async def handle_list_prompts(request: Any) -> ListPromptsResult: + return ListPromptsResult(prompts=[]) + + # Deprecation warning should be issued for Any type (effectively untyped) + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 1 + assert "ListPromptsRequest" in str(deprecation_warnings[0].message) + + +@pytest.mark.anyio +async def test_typevar_with_bound_no_warning() -> None: + """Test that TypeVar with matching bound doesn't trigger warning.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + bound_request = TypeVar("bound_request", bound=ListPromptsRequest) + + @server.list_prompts() + async def handle_list_prompts(request: bound_request) -> ListPromptsResult: # type: ignore[reportInvalidTypeVarUse] + return ListPromptsResult(prompts=[]) + + # No warning - TypeVar is bound to ListPromptsRequest + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 0 + + +@pytest.mark.anyio +async def test_typevar_with_constraints_no_warning() -> None: + """Test that TypeVar with matching constraint doesn't trigger warning.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_prompts() + async def handle_list_prompts(request: ConstrainedRequest) -> ListPromptsResult: + return ListPromptsResult(prompts=[]) + + # No warning - TypeVar has ListPromptsRequest as a constraint + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 0 + + +@pytest.mark.anyio +async def test_any_type_still_receives_request() -> None: + """Test that handlers with Any type still receive the request object.""" + server = Server("test") + received_request: Any = None + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + @server.list_prompts() + async def handle_list_prompts(request: Any) -> ListPromptsResult: + nonlocal received_request + received_request = request + return ListPromptsResult(prompts=[]) + + handler = server.request_handlers[ListPromptsRequest] + request = ListPromptsRequest(method="prompts/list", params=None) + result = await handler(request) + + assert received_request is not None + assert isinstance(received_request, ListPromptsRequest) + assert isinstance(result, ServerResult) + + +@pytest.mark.anyio +async def test_union_handler_receives_correct_request() -> None: + """Test that Union-typed handlers receive the request correctly.""" + server = Server("test") + received_request: ListPromptsRequest | None = None + + @server.list_prompts() + async def handle_list_prompts(request: ListPromptsRequest | None) -> ListPromptsResult: + nonlocal received_request + received_request = request + return ListPromptsResult(prompts=[Prompt(name="test")]) + + handler = server.request_handlers[ListPromptsRequest] + request = ListPromptsRequest(method="prompts/list", params=None) + result = await handler(request) + + assert received_request is not None + assert isinstance(received_request, ListPromptsRequest) + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListPromptsResult) + assert len(result.root.prompts) == 1 + + +@pytest.mark.anyio +async def test_wrong_union_type_triggers_warning() -> None: + """Test that Union without the request type triggers deprecation warning.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_prompts() # type: ignore[arg-type] # Intentionally testing incorrect type for deprecation warning + async def handle_list_prompts(request: str | int) -> list[Prompt]: + return [] + + # Deprecation warning should be issued - Union doesn't contain request type + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 1 + + +@pytest.mark.anyio +async def test_generic_typevar_no_warning() -> None: + """Test that generic TypeVar doesn't trigger warning.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_prompts() + async def handle_list_prompts(request: T) -> ListPromptsResult: # type: ignore[valid-type] + return ListPromptsResult(prompts=[]) + + # Generic TypeVar without bounds - should not trigger warning but will receive request + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + # This may or may not warn depending on implementation - the key is it shouldn't break + assert len(deprecation_warnings) in [0, 1] # Either way is acceptable diff --git a/tests/server/lowlevel/test_deprecation_warnings.py b/tests/server/lowlevel/test_deprecation_warnings.py new file mode 100644 index 000000000..08346c9a3 --- /dev/null +++ b/tests/server/lowlevel/test_deprecation_warnings.py @@ -0,0 +1,194 @@ +import warnings + +import pytest + +from mcp.server import Server +from mcp.types import ( + ListPromptsRequest, + ListPromptsResult, + ListResourcesRequest, + ListResourcesResult, + ListToolsRequest, + ListToolsResult, + PaginatedRequestParams, + Prompt, + Resource, + ServerResult, + Tool, +) + + +@pytest.mark.anyio +async def test_list_prompts_with_typed_request_no_warning() -> None: + """Test that properly typed handlers don't trigger deprecation warnings.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_prompts() + async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: + return ListPromptsResult(prompts=[]) + + # No deprecation warning should be issued + assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 + + +@pytest.mark.anyio +async def test_list_prompts_without_params_triggers_warning() -> None: + """Test that handlers without parameters trigger deprecation warnings.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_prompts() + async def handle_list_prompts() -> list[Prompt]: + return [] + + # A deprecation warning should be issued + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 1 + assert "ListPromptsRequest" in str(deprecation_warnings[0].message) + + +@pytest.mark.anyio +async def test_list_prompts_with_untyped_param_triggers_warning() -> None: + """Test that handlers with untyped parameters trigger deprecation warnings.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_prompts() + async def handle_list_prompts(request) -> ListPromptsResult: # type: ignore[no-untyped-def] + return ListPromptsResult(prompts=[]) + + # A deprecation warning should be issued + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 1 + assert "ListPromptsRequest" in str(deprecation_warnings[0].message) + + +@pytest.mark.anyio +async def test_list_resources_with_typed_request_no_warning() -> None: + """Test that properly typed resource handlers don't trigger warnings.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_resources() + async def handle_list_resources(request: ListResourcesRequest) -> ListResourcesResult: + return ListResourcesResult(resources=[]) + + # No deprecation warning should be issued + assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 + + +@pytest.mark.anyio +async def test_list_resources_without_params_triggers_warning() -> None: + """Test that resource handlers without parameters trigger deprecation warnings.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_resources() + async def handle_list_resources() -> list[Resource]: + return [] + + # A deprecation warning should be issued + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 1 + assert "ListResourcesRequest" in str(deprecation_warnings[0].message) + + +@pytest.mark.anyio +async def test_list_tools_with_typed_request_no_warning() -> None: + """Test that properly typed tool handlers don't trigger warnings.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_tools() + async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: + return ListToolsResult(tools=[]) + + # No deprecation warning should be issued + assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 + + +@pytest.mark.anyio +async def test_list_tools_without_params_triggers_warning() -> None: + """Test that tool handlers without parameters trigger deprecation warnings.""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return [] + + # A deprecation warning should be issued + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 1 + assert "ListToolsRequest" in str(deprecation_warnings[0].message) + + +@pytest.mark.anyio +async def test_old_style_handler_still_works() -> None: + """Test that old-style handlers still work (with deprecation warning).""" + server = Server("test") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @server.list_prompts() + async def handle_list_prompts() -> list[Prompt]: + return [Prompt(name="test", description="Test prompt")] + + # Handler should be registered + assert ListPromptsRequest in server.request_handlers + + # Deprecation warning should be issued + deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + assert len(deprecation_warnings) == 1 + + # Handler should still work correctly + handler = server.request_handlers[ListPromptsRequest] + request = ListPromptsRequest(method="prompts/list", params=None) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListPromptsResult) + assert len(result.root.prompts) == 1 + assert result.root.prompts[0].name == "test" + + +@pytest.mark.anyio +async def test_new_style_handler_receives_pagination_params() -> None: + """Test that new-style handlers receive pagination parameters correctly.""" + server = Server("test") + received_request: ListPromptsRequest | None = None + + @server.list_prompts() + async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: + nonlocal received_request + received_request = request + return ListPromptsResult(prompts=[], nextCursor="next-page") + + handler = server.request_handlers[ListPromptsRequest] + + # Test with cursor + cursor_value = "test-cursor-123" + request_with_cursor = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor=cursor_value)) + result = await handler(request_with_cursor) + + assert received_request is not None + assert received_request.params is not None + assert received_request.params.cursor == cursor_value + assert isinstance(result, ServerResult) + assert result.root.nextCursor == "next-page" diff --git a/tests/server/lowlevel/test_server_listing.py b/tests/server/lowlevel/test_server_listing.py index 9474edb3f..23ac7e451 100644 --- a/tests/server/lowlevel/test_server_listing.py +++ b/tests/server/lowlevel/test_server_listing.py @@ -1,5 +1,7 @@ """Basic tests for list_prompts, list_resources, and list_tools decorators without pagination.""" +import warnings + import pytest from pydantic import AnyUrl @@ -28,9 +30,12 @@ async def test_list_prompts_basic() -> None: Prompt(name="prompt2", description="Second prompt"), ] - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return test_prompts + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + @server.list_prompts() + async def handle_list_prompts() -> list[Prompt]: + return test_prompts handler = server.request_handlers[ListPromptsRequest] request = ListPromptsRequest(method="prompts/list", params=None) @@ -51,9 +56,12 @@ async def test_list_resources_basic() -> None: Resource(uri=AnyUrl("file:///test2.txt"), name="Test 2"), ] - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return test_resources + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + @server.list_resources() + async def handle_list_resources() -> list[Resource]: + return test_resources handler = server.request_handlers[ListResourcesRequest] request = ListResourcesRequest(method="resources/list", params=None) @@ -95,9 +103,12 @@ async def test_list_tools_basic() -> None: ), ] - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return test_tools + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return test_tools handler = server.request_handlers[ListToolsRequest] request = ListToolsRequest(method="tools/list", params=None) @@ -113,9 +124,12 @@ async def test_list_prompts_empty() -> None: """Test listing with empty results.""" server = Server("test") - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return [] + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + @server.list_prompts() + async def handle_list_prompts() -> list[Prompt]: + return [] handler = server.request_handlers[ListPromptsRequest] request = ListPromptsRequest(method="prompts/list", params=None) @@ -131,9 +145,12 @@ async def test_list_resources_empty() -> None: """Test listing with empty results.""" server = Server("test") - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return [] + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + @server.list_resources() + async def handle_list_resources() -> list[Resource]: + return [] handler = server.request_handlers[ListResourcesRequest] request = ListResourcesRequest(method="resources/list", params=None) @@ -149,9 +166,12 @@ async def test_list_tools_empty() -> None: """Test listing with empty results.""" server = Server("test") - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [] + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return [] handler = server.request_handlers[ListToolsRequest] request = ListToolsRequest(method="tools/list", params=None) diff --git a/tests/server/lowlevel/test_type_accepts_request.py b/tests/server/lowlevel/test_type_accepts_request.py new file mode 100644 index 000000000..7a080ce91 --- /dev/null +++ b/tests/server/lowlevel/test_type_accepts_request.py @@ -0,0 +1,88 @@ +"""Unit tests for the type_accepts_request function.""" + +from typing import Any, TypeVar + +import pytest + +from mcp.server.lowlevel.func_inspection import type_accepts_request +from mcp.types import ListPromptsRequest, ListResourcesRequest, ListToolsRequest + + +@pytest.mark.parametrize( + "param_type,request_type,expected,description", + [ + # Exact type matches + (ListPromptsRequest, ListPromptsRequest, True, "exact type match"), + (ListToolsRequest, ListPromptsRequest, False, "different request type"), + (str, ListPromptsRequest, False, "string type"), + (int, ListPromptsRequest, False, "int type"), + (None, ListPromptsRequest, False, "None type"), + # Any type + (Any, ListPromptsRequest, True, "Any type accepts all"), + # Union types with request type + (ListPromptsRequest | None, ListPromptsRequest, True, "Optional request type"), + (str | ListPromptsRequest, ListPromptsRequest, True, "Union with request type (request second)"), + (ListPromptsRequest | str, ListPromptsRequest, True, "Union with request type (request first)"), + ( + ListPromptsRequest | ListToolsRequest, + ListPromptsRequest, + True, + "Union of multiple request types", + ), + # Union types without request type + (str | int, ListPromptsRequest, False, "Union of primitives"), + ( + ListToolsRequest | ListResourcesRequest, + ListPromptsRequest, + False, + "Union of different request types", + ), + (str | None, ListPromptsRequest, False, "Optional string"), + # Nested unions + ( + ListPromptsRequest | str | int, + ListPromptsRequest, + True, + "nested Union with request type", + ), + (str | int | bool, ListPromptsRequest, False, "nested Union without request type"), + # Generic types + (list[str], ListPromptsRequest, False, "generic list type"), + (list[ListPromptsRequest], ListPromptsRequest, False, "list of requests"), + ], +) +def test_type_accepts_request_simple( + param_type: Any, + request_type: type, + expected: bool, + description: str, +) -> None: + """Test type_accepts_request with simple type combinations.""" + assert type_accepts_request(param_type, request_type) is expected, f"Failed: {description}" + + +@pytest.mark.parametrize( + "typevar_factory,expected,description", + [ + # TypeVar with bounds + (lambda: TypeVar("BoundRequest", bound=ListPromptsRequest), True, "TypeVar bound to request type"), + (lambda: TypeVar("BoundString", bound=str), False, "TypeVar bound to different type"), + # TypeVar with constraints + ( + lambda: TypeVar("ConstrainedRequest", ListPromptsRequest, ListToolsRequest), + True, + "TypeVar constrained to include request type", + ), + (lambda: TypeVar("ConstrainedPrimitives", str, int), False, "TypeVar constrained to primitives"), + # TypeVar without bounds or constraints + (lambda: TypeVar("T"), False, "unbounded TypeVar"), + ], +) +def test_type_accepts_request_typevar( + typevar_factory: Any, + expected: bool, + description: str, +) -> None: + """Test type_accepts_request with TypeVar types.""" + param_type = typevar_factory() + assert type_accepts_request(param_type, ListPromptsRequest) is expected, f"Failed: {description}" From cbb0e37e956e4bb8d9faf9434b8efc882f7756a1 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 18 Sep 2025 22:25:14 +0100 Subject: [PATCH 10/11] feat: change to request injection on type rather than positional --- src/mcp/server/lowlevel/func_inspection.py | 255 ++------- src/mcp/server/lowlevel/server.py | 87 ++- .../lowlevel/test_advanced_type_inspection.py | 193 ------- .../lowlevel/test_deprecation_warnings.py | 390 +++++++------- tests/server/lowlevel/test_func_inspection.py | 497 ++++++++++++------ .../lowlevel/test_type_accepts_request.py | 88 ---- 6 files changed, 628 insertions(+), 882 deletions(-) delete mode 100644 tests/server/lowlevel/test_advanced_type_inspection.py delete mode 100644 tests/server/lowlevel/test_type_accepts_request.py diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py index 1e6d862fa..8bd69ec22 100644 --- a/src/mcp/server/lowlevel/func_inspection.py +++ b/src/mcp/server/lowlevel/func_inspection.py @@ -1,217 +1,72 @@ import inspect -import types import warnings from collections.abc import Callable -from typing import Any, TypeVar, Union, get_args, get_origin +from typing import Any, TypeVar, get_type_hints -def accepts_single_positional_arg(func: Callable[..., Any]) -> bool: +def issue_deprecation_warning(func: Callable[..., Any], request_type: type) -> None: """ - True if the function accepts at least one positional argument, otherwise false. - - This function intentionally does not define behavior for `func`s that - contain more than one positional argument, or any required keyword - arguments without defaults. + Issue a deprecation warning for handlers that don't use the new request parameter style. """ - try: - sig = inspect.signature(func) - except (ValueError, TypeError): - return False - - params = dict(sig.parameters.items()) - - if len(params) == 0: - # No parameters at all - can't accept single argument - return False - - # Check if ALL remaining parameters are keyword-only - all_keyword_only = all(param.kind == inspect.Parameter.KEYWORD_ONLY for param in params.values()) - - if all_keyword_only: - # If all params are keyword-only, check if they ALL have defaults - # If they do, the function can be called with no arguments -> no argument - all_have_defaults = all(param.default is not inspect.Parameter.empty for param in params.values()) - if all_have_defaults: - return False - # otherwise, undefined (doesn't accept a positional argument, and requires at least one keyword only) + func_name = getattr(func, "__name__", str(func)) + warnings.warn( + f"Handler '{func_name}' should accept a '{request_type.__name__}' parameter. " + "Support for handlers without this parameter will be removed in a future version.", + DeprecationWarning, + stacklevel=4, + ) - # Check if the ONLY parameter is **kwargs (VAR_KEYWORD) - # A function with only **kwargs can't accept a positional argument - if len(params) == 1: - only_param = next(iter(params.values())) - if only_param.kind == inspect.Parameter.VAR_KEYWORD: - return False # Can't pass positional argument to **kwargs - # Has at least one positional or variadic parameter - can accept argument - # Important note: this is designed to _not_ handle the situation where - # there are multiple keyword only arguments with no defaults. In those - # situations it's an invalid handler function, and will error. But it's - # not the responsibility of this function to check the validity of a - # callback. - return True +T = TypeVar("T") +R = TypeVar("R") -def get_first_parameter_type(func: Callable[..., Any]) -> Any: +def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> tuple[Callable[[T], R], bool]: """ - Get the type annotation of the first parameter of a function. + Create a wrapper function that knows how to call func with the request object. - Returns None if: - - The function has no parameters - - The first parameter has no type annotation - - The signature cannot be inspected + Returns a tuple of (wrapper_func, should_deprecate): + - wrapper_func: A function that takes the request and calls func appropriately + - should_deprecate: True if a deprecation warning should be issued - Returns the actual annotation otherwise (could be a type, Any, Union, TypeVar, etc.) + The wrapper handles three calling patterns: + 1. Positional-only parameter typed as request_type (no default): func(req) + 2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req}) + 3. No request parameter or parameter with default (deprecated): func() """ try: sig = inspect.signature(func) - except (ValueError, TypeError): - return None - - params = list(sig.parameters.values()) - if not params: - return None - - first_param = params[0] - - # Skip *args and **kwargs - if first_param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): - return None - - annotation = first_param.annotation - if annotation == inspect.Parameter.empty: - return None - - return annotation - - -def type_accepts_request(param_type: Any, request_type: type) -> bool: - """ - Check if a parameter type annotation can accept the request type. - - Handles: - - Exact type match - - Union types (checks if request_type is in the Union) - - TypeVars (checks if request_type matches the bound or constraints) - - Generic types (basic support) - - Any (always returns True) - - Returns False for None or incompatible types. - """ - if param_type is None: - return False - - # Check for Any type - if param_type is Any: - return True - - # Exact match - if param_type == request_type: - return True - - # Handle Union types (both typing.Union and | syntax) - origin = get_origin(param_type) - if origin is Union or origin is types.UnionType: - args = get_args(param_type) - # Check if request_type is in the Union - for arg in args: - if arg == request_type: - return True - # Recursively check each union member - if type_accepts_request(arg, request_type): - return True - return False - - # Handle TypeVar - if isinstance(param_type, TypeVar): - # Check if request_type matches the bound - if param_type.__bound__ is not None: - if request_type == param_type.__bound__: - return True - # Check if request_type is a subclass of the bound - try: - if issubclass(request_type, param_type.__bound__): - return True - except TypeError: - pass - - # Check constraints - if param_type.__constraints__: - for constraint in param_type.__constraints__: - if request_type == constraint: - return True - try: - if issubclass(request_type, constraint): - return True - except TypeError: - pass - - return False - - # For other generic types, check if request_type matches the origin - if origin is not None: - # Get the base generic type (e.g., list from list[str]) - return request_type == origin - - return False - - -def should_pass_request(func: Callable[..., Any], request_type: type) -> tuple[bool, bool]: - """ - Determine if a request should be passed to the function based on parameter type inspection. - - Returns a tuple of (should_pass_request, should_deprecate): - - should_pass_request: True if the request should be passed to the function - - should_deprecate: True if a deprecation warning should be issued - - The decision logic: - 1. If the function has no parameters -> (False, True) - old style without params, deprecate - 2. If the function has parameters but can't accept positional args -> (False, False) - 3. If the first parameter type accepts the request type -> (True, False) - pass request, no deprecation - 4. If the first parameter is typed as Any -> (True, True) - pass request but deprecate (effectively untyped) - 5. If the first parameter is typed with something incompatible -> (False, True) - old style, deprecate - 6. If the first parameter is untyped but accepts positional args -> (True, True) - pass request, deprecate - """ - can_accept_arg = accepts_single_positional_arg(func) - - if not can_accept_arg: - # Check if it has no parameters at all (old style) - try: - sig = inspect.signature(func) - if len(sig.parameters) == 0: - # Old style handler with no parameters - don't pass request but deprecate - return False, True - except (ValueError, TypeError): - pass - # Can't accept positional arguments for other reasons - return False, False - - param_type = get_first_parameter_type(func) - - if param_type is None: - # Untyped parameter - this is the old style, pass request but deprecate - return True, True - - # Check if the parameter type can accept the request - if type_accepts_request(param_type, request_type): - # Check if it's Any - if so, we should deprecate - if param_type is Any: - return True, True - # Properly typed to accept the request - pass request, no deprecation - return True, False - - # Parameter is typed with something incompatible - this is an old style handler expecting - # a different signature, don't pass request, issue deprecation - return False, True - - -def issue_deprecation_warning(func: Callable[..., Any], request_type: type) -> None: - """ - Issue a deprecation warning for handlers that don't use the new request parameter style. - """ - func_name = getattr(func, "__name__", str(func)) - warnings.warn( - f"Handler '{func_name}' should accept a '{request_type.__name__}' parameter. " - "Support for handlers without this parameter will be removed in a future version.", - DeprecationWarning, - stacklevel=4, - ) + type_hints = get_type_hints(func) + except (ValueError, TypeError, NameError): + # Can't inspect signature or resolve type hints, assume no request parameter (deprecated) + return lambda _: func(), True + + # Check for positional-only parameter typed as request_type + for param_name, param in sig.parameters.items(): + if param.kind == inspect.Parameter.POSITIONAL_ONLY: + param_type = type_hints.get(param_name) + if param_type == request_type: + # Check if it has a default - if so, treat as old style (deprecated) + if param.default is not inspect.Parameter.empty: + return lambda _: func(), True + # Found positional-only parameter with correct type and no default + return lambda req: func(req), False + + # Check for any positional/keyword parameter typed as request_type + for param_name, param in sig.parameters.items(): + if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): + param_type = type_hints.get(param_name) + if param_type == request_type: + # Check if it has a default - if so, treat as old style (deprecated) + if param.default is not inspect.Parameter.empty: + return lambda _: func(), True + + # Found keyword parameter with correct type and no default + # Need to capture param_name in closure properly + def make_keyword_wrapper(name: str) -> Callable[[Any], Any]: + return lambda req: func(**{name: req}) + + return make_keyword_wrapper(param_name), False + + # No request parameter found - use old style (deprecated) + return lambda _: func(), True diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index ac61849bc..c3b978ecc 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -11,7 +11,7 @@ 2. Define request handlers using decorators: @server.list_prompts() - async def handle_list_prompts() -> list[types.Prompt]: + async def handle_list_prompts(request: types.ListPromptsRequest) -> types.ListPromptsResult: # Implementation @server.get_prompt() @@ -21,7 +21,7 @@ async def handle_get_prompt( # Implementation @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: + async def handle_list_tools(request: types.ListToolsRequest) -> types.ListToolsResult: # Implementation @server.call_tool() @@ -82,7 +82,7 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types -from mcp.server.lowlevel.func_inspection import issue_deprecation_warning, should_pass_request +from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -235,28 +235,22 @@ def decorator( | Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], ): logger.debug("Registering handler for PromptListRequest") - pass_request, should_deprecate = should_pass_request(func, types.ListPromptsRequest) - if should_deprecate: - issue_deprecation_warning(func, types.ListPromptsRequest) + # Create wrapper that knows how to call func with the request + wrapper, _ = create_call_wrapper(func, types.ListPromptsRequest) - if pass_request: - request_func = cast(Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], func) + # if should_deprecate: + # issue_deprecation_warning(func, types.ListPromptsRequest) - async def request_handler(req: types.ListPromptsRequest): - result = await request_func(req) + async def handler(req: types.ListPromptsRequest): + result = await wrapper(req) + # Handle both old style (list[Prompt]) and new style (ListPromptsResult) + if isinstance(result, types.ListPromptsResult): return types.ServerResult(result) - - handler = request_handler - else: - list_func = cast(Callable[[], Awaitable[list[types.Prompt]]], func) - - async def list_handler(_: types.ListPromptsRequest): - result = await list_func() + else: + # Old style returns list[Prompt] return types.ServerResult(types.ListPromptsResult(prompts=result)) - handler = list_handler - self.request_handlers[types.ListPromptsRequest] = handler return func @@ -283,28 +277,23 @@ def decorator( | Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], ): logger.debug("Registering handler for ListResourcesRequest") - pass_request, should_deprecate = should_pass_request(func, types.ListResourcesRequest) - if should_deprecate: - issue_deprecation_warning(func, types.ListResourcesRequest) + # Create wrapper that knows how to call func with the request + wrapper, _ = create_call_wrapper(func, types.ListResourcesRequest) - if pass_request: - request_func = cast(Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], func) + # TODO: Decide whether we want this sort of deprecation in a later PR + # if should_deprecate: + # issue_deprecation_warning(func, types.ListResourcesRequest) - async def request_handler(req: types.ListResourcesRequest): - result = await request_func(req) + async def handler(req: types.ListResourcesRequest): + result = await wrapper(req) + # Handle both old style (list[Resource]) and new style (ListResourcesResult) + if isinstance(result, types.ListResourcesResult): return types.ServerResult(result) - - handler = request_handler - else: - list_func = cast(Callable[[], Awaitable[list[types.Resource]]], func) - - async def list_handler(_: types.ListResourcesRequest): - result = await list_func() + else: + # Old style returns list[Resource] return types.ServerResult(types.ListResourcesResult(resources=result)) - handler = list_handler - self.request_handlers[types.ListResourcesRequest] = handler return func @@ -426,35 +415,31 @@ def decorator( | Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], ): logger.debug("Registering handler for ListToolsRequest") - pass_request, should_deprecate = should_pass_request(func, types.ListToolsRequest) - if should_deprecate: - issue_deprecation_warning(func, types.ListToolsRequest) + # Create wrapper that knows how to call func with the request + wrapper, _ = create_call_wrapper(func, types.ListToolsRequest) + + # TODO: Decide whether we want this sort of deprecation in a later PR + # if should_deprecate: + # issue_deprecation_warning(func, types.ListToolsRequest) - if pass_request: - request_func = cast(Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], func) + async def handler(req: types.ListToolsRequest): + result = await wrapper(req) - async def request_handler(req: types.ListToolsRequest): - result = await request_func(req) + # Handle both old style (list[Tool]) and new style (ListToolsResult) + if isinstance(result, types.ListToolsResult): # Refresh the tool cache with returned tools for tool in result.tools: self._tool_cache[tool.name] = tool return types.ServerResult(result) - - handler = request_handler - else: - list_func = cast(Callable[[], Awaitable[list[types.Tool]]], func) - - async def list_handler(req: types.ListToolsRequest): - result = await list_func() + else: + # Old style returns list[Tool] # Clear and refresh the entire tool cache self._tool_cache.clear() for tool in result: self._tool_cache[tool.name] = tool return types.ServerResult(types.ListToolsResult(tools=result)) - handler = list_handler - self.request_handlers[types.ListToolsRequest] = handler return func diff --git a/tests/server/lowlevel/test_advanced_type_inspection.py b/tests/server/lowlevel/test_advanced_type_inspection.py deleted file mode 100644 index 0dcd27941..000000000 --- a/tests/server/lowlevel/test_advanced_type_inspection.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Tests for advanced type inspection features in pagination handlers.""" - -import warnings -from typing import Any, TypeVar - -import pytest - -from mcp.server import Server -from mcp.types import ( - ListPromptsRequest, - ListPromptsResult, - ListToolsRequest, - Prompt, - ServerResult, -) - -# Define TypeVars for testing -T = TypeVar("T") -ConstrainedRequest = TypeVar("ConstrainedRequest", ListPromptsRequest, ListToolsRequest) - - -@pytest.mark.anyio -async def test_union_type_with_request_no_warning() -> None: - """Test that Union types containing the request type don't trigger warnings.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_prompts() - async def handle_list_prompts(request: ListPromptsRequest | None) -> ListPromptsResult: - assert request is not None - return ListPromptsResult(prompts=[]) - - # No deprecation warning should be issued for Union containing request type - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 0 - - -@pytest.mark.anyio -async def test_union_type_multiple_requests_no_warning() -> None: - """Test Union with multiple request types works correctly.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_prompts() - async def handle_list_prompts( - request: ListPromptsRequest | ListToolsRequest, - ) -> ListPromptsResult: - assert isinstance(request, ListPromptsRequest) - return ListPromptsResult(prompts=[]) - - # No deprecation warning - Union contains the request type - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 0 - - -@pytest.mark.anyio -async def test_any_type_triggers_warning() -> None: - """Test that Any type triggers deprecation warning.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_prompts() - async def handle_list_prompts(request: Any) -> ListPromptsResult: - return ListPromptsResult(prompts=[]) - - # Deprecation warning should be issued for Any type (effectively untyped) - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 1 - assert "ListPromptsRequest" in str(deprecation_warnings[0].message) - - -@pytest.mark.anyio -async def test_typevar_with_bound_no_warning() -> None: - """Test that TypeVar with matching bound doesn't trigger warning.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - bound_request = TypeVar("bound_request", bound=ListPromptsRequest) - - @server.list_prompts() - async def handle_list_prompts(request: bound_request) -> ListPromptsResult: # type: ignore[reportInvalidTypeVarUse] - return ListPromptsResult(prompts=[]) - - # No warning - TypeVar is bound to ListPromptsRequest - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 0 - - -@pytest.mark.anyio -async def test_typevar_with_constraints_no_warning() -> None: - """Test that TypeVar with matching constraint doesn't trigger warning.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_prompts() - async def handle_list_prompts(request: ConstrainedRequest) -> ListPromptsResult: - return ListPromptsResult(prompts=[]) - - # No warning - TypeVar has ListPromptsRequest as a constraint - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 0 - - -@pytest.mark.anyio -async def test_any_type_still_receives_request() -> None: - """Test that handlers with Any type still receive the request object.""" - server = Server("test") - received_request: Any = None - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_prompts() - async def handle_list_prompts(request: Any) -> ListPromptsResult: - nonlocal received_request - received_request = request - return ListPromptsResult(prompts=[]) - - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) - - assert received_request is not None - assert isinstance(received_request, ListPromptsRequest) - assert isinstance(result, ServerResult) - - -@pytest.mark.anyio -async def test_union_handler_receives_correct_request() -> None: - """Test that Union-typed handlers receive the request correctly.""" - server = Server("test") - received_request: ListPromptsRequest | None = None - - @server.list_prompts() - async def handle_list_prompts(request: ListPromptsRequest | None) -> ListPromptsResult: - nonlocal received_request - received_request = request - return ListPromptsResult(prompts=[Prompt(name="test")]) - - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) - - assert received_request is not None - assert isinstance(received_request, ListPromptsRequest) - assert isinstance(result, ServerResult) - assert isinstance(result.root, ListPromptsResult) - assert len(result.root.prompts) == 1 - - -@pytest.mark.anyio -async def test_wrong_union_type_triggers_warning() -> None: - """Test that Union without the request type triggers deprecation warning.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_prompts() # type: ignore[arg-type] # Intentionally testing incorrect type for deprecation warning - async def handle_list_prompts(request: str | int) -> list[Prompt]: - return [] - - # Deprecation warning should be issued - Union doesn't contain request type - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 1 - - -@pytest.mark.anyio -async def test_generic_typevar_no_warning() -> None: - """Test that generic TypeVar doesn't trigger warning.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_prompts() - async def handle_list_prompts(request: T) -> ListPromptsResult: # type: ignore[valid-type] - return ListPromptsResult(prompts=[]) - - # Generic TypeVar without bounds - should not trigger warning but will receive request - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - # This may or may not warn depending on implementation - the key is it shouldn't break - assert len(deprecation_warnings) in [0, 1] # Either way is acceptable diff --git a/tests/server/lowlevel/test_deprecation_warnings.py b/tests/server/lowlevel/test_deprecation_warnings.py index 08346c9a3..f030b9459 100644 --- a/tests/server/lowlevel/test_deprecation_warnings.py +++ b/tests/server/lowlevel/test_deprecation_warnings.py @@ -1,194 +1,196 @@ -import warnings - -import pytest - -from mcp.server import Server -from mcp.types import ( - ListPromptsRequest, - ListPromptsResult, - ListResourcesRequest, - ListResourcesResult, - ListToolsRequest, - ListToolsResult, - PaginatedRequestParams, - Prompt, - Resource, - ServerResult, - Tool, -) - - -@pytest.mark.anyio -async def test_list_prompts_with_typed_request_no_warning() -> None: - """Test that properly typed handlers don't trigger deprecation warnings.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_prompts() - async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: - return ListPromptsResult(prompts=[]) - - # No deprecation warning should be issued - assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 - - -@pytest.mark.anyio -async def test_list_prompts_without_params_triggers_warning() -> None: - """Test that handlers without parameters trigger deprecation warnings.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return [] - - # A deprecation warning should be issued - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 1 - assert "ListPromptsRequest" in str(deprecation_warnings[0].message) - - -@pytest.mark.anyio -async def test_list_prompts_with_untyped_param_triggers_warning() -> None: - """Test that handlers with untyped parameters trigger deprecation warnings.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_prompts() - async def handle_list_prompts(request) -> ListPromptsResult: # type: ignore[no-untyped-def] - return ListPromptsResult(prompts=[]) - - # A deprecation warning should be issued - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 1 - assert "ListPromptsRequest" in str(deprecation_warnings[0].message) - - -@pytest.mark.anyio -async def test_list_resources_with_typed_request_no_warning() -> None: - """Test that properly typed resource handlers don't trigger warnings.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_resources() - async def handle_list_resources(request: ListResourcesRequest) -> ListResourcesResult: - return ListResourcesResult(resources=[]) - - # No deprecation warning should be issued - assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 - - -@pytest.mark.anyio -async def test_list_resources_without_params_triggers_warning() -> None: - """Test that resource handlers without parameters trigger deprecation warnings.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return [] - - # A deprecation warning should be issued - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 1 - assert "ListResourcesRequest" in str(deprecation_warnings[0].message) - - -@pytest.mark.anyio -async def test_list_tools_with_typed_request_no_warning() -> None: - """Test that properly typed tool handlers don't trigger warnings.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: - return ListToolsResult(tools=[]) - - # No deprecation warning should be issued - assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 - - -@pytest.mark.anyio -async def test_list_tools_without_params_triggers_warning() -> None: - """Test that tool handlers without parameters trigger deprecation warnings.""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [] - - # A deprecation warning should be issued - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 1 - assert "ListToolsRequest" in str(deprecation_warnings[0].message) - - -@pytest.mark.anyio -async def test_old_style_handler_still_works() -> None: - """Test that old-style handlers still work (with deprecation warning).""" - server = Server("test") - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return [Prompt(name="test", description="Test prompt")] - - # Handler should be registered - assert ListPromptsRequest in server.request_handlers - - # Deprecation warning should be issued - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] - assert len(deprecation_warnings) == 1 - - # Handler should still work correctly - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result.root, ListPromptsResult) - assert len(result.root.prompts) == 1 - assert result.root.prompts[0].name == "test" - - -@pytest.mark.anyio -async def test_new_style_handler_receives_pagination_params() -> None: - """Test that new-style handlers receive pagination parameters correctly.""" - server = Server("test") - received_request: ListPromptsRequest | None = None - - @server.list_prompts() - async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: - nonlocal received_request - received_request = request - return ListPromptsResult(prompts=[], nextCursor="next-page") - - handler = server.request_handlers[ListPromptsRequest] - - # Test with cursor - cursor_value = "test-cursor-123" - request_with_cursor = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor=cursor_value)) - result = await handler(request_with_cursor) - - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == cursor_value - assert isinstance(result, ServerResult) - assert result.root.nextCursor == "next-page" +# TODO: Decide whether we want deprecation warnings in another PR +# import warnings +# +# import pytest +# +# from mcp.server import Server +# from mcp.types import ( +# ListPromptsRequest, +# ListPromptsResult, +# ListResourcesRequest, +# ListResourcesResult, +# ListToolsRequest, +# ListToolsResult, +# PaginatedRequestParams, +# Prompt, +# Resource, +# ServerResult, +# Tool, +# ) +# +# +# @pytest.mark.anyio +# async def test_list_prompts_with_typed_request_no_warning() -> None: +# """Test that properly typed handlers don't trigger deprecation warnings.""" +# server = Server("test") +# +# with warnings.catch_warnings(record=True) as w: +# warnings.simplefilter("always") +# +# @server.list_prompts() +# async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: +# return ListPromptsResult(prompts=[]) +# +# # No deprecation warning should be issued +# assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 +# +# +# @pytest.mark.anyio +# async def test_list_prompts_without_params_triggers_warning() -> None: +# """Test that handlers without parameters trigger deprecation warnings.""" +# server = Server("test") +# +# with warnings.catch_warnings(record=True) as w: +# warnings.simplefilter("always") +# +# @server.list_prompts() +# async def handle_list_prompts() -> list[Prompt]: +# return [] +# +# # A deprecation warning should be issued +# deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] +# assert len(deprecation_warnings) == 1 +# assert "ListPromptsRequest" in str(deprecation_warnings[0].message) +# +# +# @pytest.mark.anyio +# async def test_list_prompts_with_untyped_param_triggers_warning() -> None: +# """Test that handlers with untyped parameters trigger deprecation warnings.""" +# server = Server("test") +# +# with warnings.catch_warnings(record=True) as w: +# warnings.simplefilter("always") +# +# @server.list_prompts() +# async def handle_list_prompts(request) -> ListPromptsResult: # type: ignore[no-untyped-def] +# return ListPromptsResult(prompts=[]) +# +# # A deprecation warning should be issued +# deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] +# assert len(deprecation_warnings) == 1 +# assert "ListPromptsRequest" in str(deprecation_warnings[0].message) +# +# +# @pytest.mark.anyio +# async def test_list_resources_with_typed_request_no_warning() -> None: +# """Test that properly typed resource handlers don't trigger warnings.""" +# server = Server("test") +# +# with warnings.catch_warnings(record=True) as w: +# warnings.simplefilter("always") +# +# @server.list_resources() +# async def handle_list_resources(request: ListResourcesRequest) -> ListResourcesResult: +# return ListResourcesResult(resources=[]) +# +# # No deprecation warning should be issued +# assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 +# +# +# @pytest.mark.anyio +# async def test_list_resources_without_params_triggers_warning() -> None: +# """Test that resource handlers without parameters trigger deprecation warnings.""" +# server = Server("test") +# +# with warnings.catch_warnings(record=True) as w: +# warnings.simplefilter("always") +# +# @server.list_resources() +# async def handle_list_resources() -> list[Resource]: +# return [] +# +# # A deprecation warning should be issued +# deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] +# assert len(deprecation_warnings) == 1 +# assert "ListResourcesRequest" in str(deprecation_warnings[0].message) +# +# +# @pytest.mark.anyio +# async def test_list_tools_with_typed_request_no_warning() -> None: +# """Test that properly typed tool handlers don't trigger warnings.""" +# server = Server("test") +# +# with warnings.catch_warnings(record=True) as w: +# warnings.simplefilter("always") +# +# @server.list_tools() +# async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: +# return ListToolsResult(tools=[]) +# +# # No deprecation warning should be issued +# assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 +# +# +# @pytest.mark.anyio +# async def test_list_tools_without_params_triggers_warning() -> None: +# """Test that tool handlers without parameters trigger deprecation warnings.""" +# server = Server("test") +# +# with warnings.catch_warnings(record=True) as w: +# warnings.simplefilter("always") +# +# @server.list_tools() +# async def handle_list_tools() -> list[Tool]: +# return [] +# +# # A deprecation warning should be issued +# deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] +# assert len(deprecation_warnings) == 1 +# assert "ListToolsRequest" in str(deprecation_warnings[0].message) +# +# +# @pytest.mark.anyio +# async def test_old_style_handler_still_works() -> None: +# """Test that old-style handlers still work (with deprecation warning).""" +# server = Server("test") +# +# with warnings.catch_warnings(record=True) as w: +# warnings.simplefilter("always") +# +# @server.list_prompts() +# async def handle_list_prompts() -> list[Prompt]: +# return [Prompt(name="test", description="Test prompt")] +# +# # Handler should be registered +# assert ListPromptsRequest in server.request_handlers +# +# # Deprecation warning should be issued +# deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] +# assert len(deprecation_warnings) == 1 +# +# # Handler should still work correctly +# handler = server.request_handlers[ListPromptsRequest] +# request = ListPromptsRequest(method="prompts/list", params=None) +# result = await handler(request) +# +# assert isinstance(result, ServerResult) +# assert isinstance(result.root, ListPromptsResult) +# assert len(result.root.prompts) == 1 +# assert result.root.prompts[0].name == "test" +# +# +# @pytest.mark.anyio +# async def test_new_style_handler_receives_pagination_params() -> None: +# """Test that new-style handlers receive pagination parameters correctly.""" +# server = Server("test") +# received_request: ListPromptsRequest | None = None +# +# @server.list_prompts() +# async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: +# nonlocal received_request +# received_request = request +# return ListPromptsResult(prompts=[], nextCursor="next-page") +# +# handler = server.request_handlers[ListPromptsRequest] +# +# # Test with cursor +# cursor_value = "test-cursor-123" +# request_with_cursor = ListPromptsRequest(method="prompts/list", +# params=PaginatedRequestParams(cursor=cursor_value)) +# result = await handler(request_with_cursor) +# +# assert received_request is not None +# assert received_request.params is not None +# assert received_request.params.cursor == cursor_value +# assert isinstance(result, ServerResult) +# assert result.root.nextCursor == "next-page" diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py index d05538e52..f2e7233c0 100644 --- a/tests/server/lowlevel/test_func_inspection.py +++ b/tests/server/lowlevel/test_func_inspection.py @@ -1,176 +1,361 @@ -from collections.abc import Callable -from typing import Any +"""Unit tests for func_inspection module. + +Tests the create_call_wrapper function which determines how to call handler functions +with different parameter signatures and type hints. +""" + +from typing import Any, Generic, TypeVar import pytest -from mcp import types -from mcp.server.lowlevel.func_inspection import accepts_single_positional_arg +from mcp.server.lowlevel.func_inspection import create_call_wrapper +from mcp.types import ListPromptsRequest, ListResourcesRequest, ListToolsRequest, PaginatedRequestParams + +T = TypeVar("T") + + +@pytest.mark.anyio +async def test_no_params_returns_deprecated_wrapper() -> None: + """Test: def foo() - should call without request and mark as deprecated.""" + called_without_request = False + + async def handler() -> list[str]: + nonlocal called_without_request + called_without_request = True + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + assert should_deprecate is True + + # Wrapper should call handler without passing request + request = ListPromptsRequest(method="prompts/list", params=None) + result = await wrapper(request) + assert called_without_request is True + assert result == ["test"] + + +@pytest.mark.anyio +async def test_param_with_default_returns_deprecated_wrapper() -> None: + """Test: def foo(thing: int = 1) - should call without request and mark as deprecated.""" + called_without_request = False + + async def handler(thing: int = 1) -> list[str]: + nonlocal called_without_request + called_without_request = True + return [f"test-{thing}"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + assert should_deprecate is True + + # Wrapper should call handler without passing request (uses default value) + request = ListPromptsRequest(method="prompts/list", params=None) + result = await wrapper(request) + assert called_without_request is True + assert result == ["test-1"] + + +@pytest.mark.anyio +async def test_typed_request_param_passes_request() -> None: + """Test: def foo(req: ListPromptsRequest) - should pass request through.""" + received_request = None + + async def handler(req: ListPromptsRequest) -> list[str]: + nonlocal received_request + received_request = req + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + assert should_deprecate is False + + # Wrapper should pass request to handler + request = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor="test-cursor")) + await wrapper(request) + + assert received_request is not None + assert received_request is request + params = getattr(received_request, "params", None) + assert params is not None + assert params.cursor == "test-cursor" + + +@pytest.mark.anyio +async def test_typed_request_with_default_param_passes_request() -> None: + """Test: def foo(req: ListPromptsRequest, thing: int = 1) - should pass request through.""" + received_request = None + received_thing = None + + async def handler(req: ListPromptsRequest, thing: int = 1) -> list[str]: + nonlocal received_request, received_thing + received_request = req + received_thing = thing + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + assert should_deprecate is False + + # Wrapper should pass request to handler + request = ListPromptsRequest(method="prompts/list", params=None) + await wrapper(request) + + assert received_request is request + assert received_thing == 1 # default value + + +@pytest.mark.anyio +async def test_optional_typed_request_with_default_none_is_deprecated() -> None: + """Test: def foo(thing: int = 1, req: ListPromptsRequest | None = None) - deprecated.""" + called_without_request = False + + async def handler(thing: int = 1, req: ListPromptsRequest | None = None) -> list[str]: + nonlocal called_without_request + called_without_request = True + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + # Because req has a default value (None), it's treated as deprecated + assert should_deprecate is True + + # Wrapper should call handler without passing request + request = ListPromptsRequest(method="prompts/list", params=None) + result = await wrapper(request) + assert called_without_request is True + assert result == ["test"] + + +@pytest.mark.anyio +async def test_untyped_request_param_is_deprecated() -> None: + """Test: def foo(req) - should call without request and mark as deprecated.""" + called = False + + async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[reportMissingParameterType] + nonlocal called + called = True + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) # pyright: ignore[reportUnknownArgumentType] + + assert should_deprecate is True + + # Wrapper should call handler without passing request, which will fail because req is required + request = ListPromptsRequest(method="prompts/list", params=None) + # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it + with pytest.raises(TypeError, match="missing 1 required positional argument"): + await wrapper(request) + + +@pytest.mark.anyio +async def test_any_typed_request_param_is_deprecated() -> None: + """Test: def foo(req: Any) - should call without request and mark as deprecated.""" + + async def handler(req: Any) -> list[str]: + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + assert should_deprecate is True + + # Wrapper should call handler without passing request, which will fail because req is required + request = ListPromptsRequest(method="prompts/list", params=None) + # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it + with pytest.raises(TypeError, match="missing 1 required positional argument"): + await wrapper(request) + + +@pytest.mark.anyio +async def test_generic_typed_request_param_is_deprecated() -> None: + """Test: def foo(req: Generic[T]) - should call without request and mark as deprecated.""" + + async def handler(req: Generic[T]) -> list[str]: # pyright: ignore[reportGeneralTypeIssues] + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + assert should_deprecate is True + + # Wrapper should call handler without passing request, which will fail because req is required + request = ListPromptsRequest(method="prompts/list", params=None) + # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it + with pytest.raises(TypeError, match="missing 1 required positional argument"): + await wrapper(request) + + +@pytest.mark.anyio +async def test_wrong_typed_request_param_is_deprecated() -> None: + """Test: def foo(req: str) - should call without request and mark as deprecated.""" + + async def handler(req: str) -> list[str]: + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + assert should_deprecate is True + + # Wrapper should call handler without passing request, which will fail because req is required + request = ListPromptsRequest(method="prompts/list", params=None) + # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it + with pytest.raises(TypeError, match="missing 1 required positional argument"): + await wrapper(request) + + +@pytest.mark.anyio +async def test_required_param_before_typed_request_attempts_to_pass() -> None: + """Test: def foo(thing: int, req: ListPromptsRequest) - attempts to pass request (will fail at runtime).""" + received_request = None + + async def handler(thing: int, req: ListPromptsRequest) -> list[str]: + nonlocal received_request + received_request = req + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + # Not marked as deprecated because it has the correct type hint + assert should_deprecate is False + + # Wrapper will attempt to pass request, but it will fail at runtime + # because 'thing' is required and has no default + request = ListPromptsRequest(method="prompts/list", params=None) + + # This will raise TypeError because 'thing' is missing + with pytest.raises(TypeError, match="missing 1 required positional argument: 'thing'"): + await wrapper(request) + + +@pytest.mark.anyio +async def test_positional_only_param_with_correct_type() -> None: + """Test: def foo(req: ListPromptsRequest, /) - should pass request through.""" + received_request = None + + async def handler(req: ListPromptsRequest, /) -> list[str]: + nonlocal received_request + received_request = req + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + assert should_deprecate is False + + # Wrapper should pass request to handler + request = ListPromptsRequest(method="prompts/list", params=None) + await wrapper(request) + + assert received_request is request + + +def test_positional_only_param_with_default_is_deprecated() -> None: + """Test: def foo(req: ListPromptsRequest = None, /) - deprecated due to default value.""" + + async def handler(req: ListPromptsRequest = None, /) -> list[str]: # type: ignore[assignment] + return ["test"] + + _wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + # Has default value, so treated as deprecated + assert should_deprecate is True + + +@pytest.mark.anyio +async def test_keyword_only_param_with_correct_type() -> None: + """Test: def foo(*, req: ListPromptsRequest) - should pass request through.""" + received_request = None + + async def handler(*, req: ListPromptsRequest) -> list[str]: + nonlocal received_request + received_request = req + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + + assert should_deprecate is False + + # Wrapper should pass request to handler with keyword argument + request = ListPromptsRequest(method="prompts/list", params=None) + await wrapper(request) + + assert received_request is request + + +@pytest.mark.anyio +async def test_different_request_types() -> None: + """Test that wrapper works with different request types.""" + # Test with ListResourcesRequest + received_request = None + + async def handler(req: ListResourcesRequest) -> list[str]: + nonlocal received_request + received_request = req + return ["test"] + + wrapper, should_deprecate = create_call_wrapper(handler, ListResourcesRequest) + + assert should_deprecate is False + + request = ListResourcesRequest(method="resources/list", params=None) + await wrapper(request) + + assert received_request is request + + # Test with ListToolsRequest + received_request = None + + async def handler2(req: ListToolsRequest) -> list[str]: + nonlocal received_request + received_request = req + return ["test"] + + wrapper2, should_deprecate2 = create_call_wrapper(handler2, ListToolsRequest) + + assert should_deprecate2 is False + + request2 = ListToolsRequest(method="tools/list", params=None) + await wrapper2(request2) + + assert received_request is request2 -class MyClass: - async def no_request_method(self): - """Instance method without request parameter""" - pass +def test_lambda_without_annotations() -> None: + """Test that lambda functions work correctly.""" + # Lambda without type hints - should be deprecated + handler = lambda: ["test"] # noqa: E731 - # noinspection PyMethodParameters - async def no_request_method_bad_self_name(bad): # pyright: ignore[reportSelfClsParameterName] - """Instance method without request parameter, but with bad self name""" - pass + _wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - async def request_method(self, request: types.ListPromptsRequest): - """Instance method with request parameter""" - pass + assert should_deprecate is True - # noinspection PyMethodParameters - async def request_method_bad_self_name(bad, request: types.ListPromptsRequest): # pyright: ignore[reportSelfClsParameterName] - """Instance method with request parameter, but with bad self name""" - pass - @classmethod - async def no_request_class_method(cls): - """Class method without request parameter""" - pass +def test_function_without_type_hints_resolvable() -> None: + """Test functions where type hints can't be resolved.""" - # noinspection PyMethodParameters - @classmethod - async def no_request_class_method_bad_cls_name(bad): # pyright: ignore[reportSelfClsParameterName] - """Class method without request parameter, but with bad cls name""" - pass + def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[reportMissingParameterType] + return ["test"] - @classmethod - async def request_class_method(cls, request: types.ListPromptsRequest): - """Class method with request parameter""" - pass + # Remove type hints to simulate resolution failure + handler.__annotations__ = {} - # noinspection PyMethodParameters - @classmethod - async def request_class_method_bad_cls_name(bad, request: types.ListPromptsRequest): # pyright: ignore[reportSelfClsParameterName] - """Class method with request parameter, but with bad cls name""" - pass + _, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) # pyright: ignore[reportUnknownArgumentType] - @staticmethod - async def no_request_static_method(): - """Static method without request parameter""" - pass + # Should default to deprecated when can't determine type + assert should_deprecate is True - @staticmethod - async def request_static_method(request: types.ListPromptsRequest): - """Static method with request parameter""" - pass - @staticmethod - async def request_static_method_bad_arg_name(self: types.ListPromptsRequest): # pyright: ignore[reportSelfClsParameterName] # noqa: PLW0211 - """Static method with request parameter, but the request argument is named self""" - pass +@pytest.mark.anyio +async def test_mixed_params_with_typed_request() -> None: + """Test: def foo(a: str, req: ListPromptsRequest, b: int = 5) - attempts to pass request.""" + async def handler(a: str, req: ListPromptsRequest, b: int = 5) -> list[str]: + return ["test"] -async def no_request_func(): - """Function without request parameter""" - pass + wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) + assert should_deprecate is False -async def request_func(request: types.ListPromptsRequest): - """Function with request parameter""" - pass - - -async def request_func_different_name(req: types.ListPromptsRequest): - """Function with request parameter but different arg name""" - pass - - -async def request_func_with_self(self: types.ListPromptsRequest): - """Function with parameter named 'self' (edge case)""" - pass - - -async def var_positional_func(*args: Any): - """Function with *args""" - pass - - -async def positional_with_var_positional_func(request: types.ListPromptsRequest, *args: Any): - """Function with request and *args""" - pass - - -async def var_keyword_func(**kwargs: Any): - """Function with **kwargs""" - pass - - -async def request_with_var_keyword_func(request: types.ListPromptsRequest, **kwargs: Any): - """Function with request and **kwargs""" - pass - - -async def request_with_default(request: types.ListPromptsRequest | None = None): - """Function with request parameter having default value""" - pass - - -async def keyword_only_with_defaults(*, request: types.ListPromptsRequest | None = None): - """Function with keyword-only request with default""" - pass - - -async def keyword_only_multiple_all_defaults(*, a: str = "test", b: int = 42): - """Function with multiple keyword-only params all with defaults""" - pass - - -async def mixed_positional_and_keyword(request: types.ListPromptsRequest, *, extra: str = "test"): - """Function with positional and keyword-only params""" - pass - - -@pytest.mark.parametrize( - "callable_obj,expected,description", - [ - # Regular functions - (no_request_func, False, "function without parameters"), - (request_func, True, "function with request parameter"), - (request_func_different_name, True, "function with request (different param name)"), - (request_func_with_self, True, "function with param named 'self'"), - # Instance methods - (MyClass().no_request_method, False, "instance method without request"), - (MyClass().no_request_method_bad_self_name, False, "instance method without request (bad self name)"), - (MyClass().request_method, True, "instance method with request"), - (MyClass().request_method_bad_self_name, True, "instance method with request (bad self name)"), - # Class methods - (MyClass.no_request_class_method, False, "class method without request"), - (MyClass.no_request_class_method_bad_cls_name, False, "class method without request (bad cls name)"), - (MyClass.request_class_method, True, "class method with request"), - (MyClass.request_class_method_bad_cls_name, True, "class method with request (bad cls name)"), - # Static methods - (MyClass.no_request_static_method, False, "static method without request"), - (MyClass.request_static_method, True, "static method with request"), - (MyClass.request_static_method_bad_arg_name, True, "static method with request (bad arg name)"), - # Variadic parameters - (var_positional_func, True, "function with *args"), - (positional_with_var_positional_func, True, "function with request and *args"), - (var_keyword_func, False, "function with **kwargs"), - (request_with_var_keyword_func, True, "function with request and **kwargs"), - # Edge cases - (request_with_default, True, "function with request having default value"), - # Keyword-only parameters - (keyword_only_with_defaults, False, "keyword-only with default (can call with no args)"), - (keyword_only_multiple_all_defaults, False, "multiple keyword-only all with defaults"), - (mixed_positional_and_keyword, True, "mixed positional and keyword-only params"), - ], - ids=lambda x: x if isinstance(x, str) else "", -) -def test_accepts_single_positional_arg(callable_obj: Callable[..., Any], expected: bool, description: str): - """Test that `accepts_single_positional_arg` correctly identifies functions that accept a single argument. + # Will fail at runtime due to missing 'a' + request = ListPromptsRequest(method="prompts/list", params=None) - `accepts_single_positional_arg` is currently only used in the case of - the lowlevel server code checking whether a handler accepts a request - argument, so the test cases reference a "request" param/arg. - - The function should return True if the callable can potentially accept a positional - request argument. Returns False if: - - No parameters at all - - Only keyword-only parameters that ALL have defaults (can call with no args) - - Only **kwargs parameter (can't accept positional arguments) - """ - assert accepts_single_positional_arg(callable_obj) == expected, f"Failed for {description}" + with pytest.raises(TypeError, match="missing 1 required positional argument: 'a'"): + await wrapper(request) diff --git a/tests/server/lowlevel/test_type_accepts_request.py b/tests/server/lowlevel/test_type_accepts_request.py deleted file mode 100644 index 7a080ce91..000000000 --- a/tests/server/lowlevel/test_type_accepts_request.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Unit tests for the type_accepts_request function.""" - -from typing import Any, TypeVar - -import pytest - -from mcp.server.lowlevel.func_inspection import type_accepts_request -from mcp.types import ListPromptsRequest, ListResourcesRequest, ListToolsRequest - - -@pytest.mark.parametrize( - "param_type,request_type,expected,description", - [ - # Exact type matches - (ListPromptsRequest, ListPromptsRequest, True, "exact type match"), - (ListToolsRequest, ListPromptsRequest, False, "different request type"), - (str, ListPromptsRequest, False, "string type"), - (int, ListPromptsRequest, False, "int type"), - (None, ListPromptsRequest, False, "None type"), - # Any type - (Any, ListPromptsRequest, True, "Any type accepts all"), - # Union types with request type - (ListPromptsRequest | None, ListPromptsRequest, True, "Optional request type"), - (str | ListPromptsRequest, ListPromptsRequest, True, "Union with request type (request second)"), - (ListPromptsRequest | str, ListPromptsRequest, True, "Union with request type (request first)"), - ( - ListPromptsRequest | ListToolsRequest, - ListPromptsRequest, - True, - "Union of multiple request types", - ), - # Union types without request type - (str | int, ListPromptsRequest, False, "Union of primitives"), - ( - ListToolsRequest | ListResourcesRequest, - ListPromptsRequest, - False, - "Union of different request types", - ), - (str | None, ListPromptsRequest, False, "Optional string"), - # Nested unions - ( - ListPromptsRequest | str | int, - ListPromptsRequest, - True, - "nested Union with request type", - ), - (str | int | bool, ListPromptsRequest, False, "nested Union without request type"), - # Generic types - (list[str], ListPromptsRequest, False, "generic list type"), - (list[ListPromptsRequest], ListPromptsRequest, False, "list of requests"), - ], -) -def test_type_accepts_request_simple( - param_type: Any, - request_type: type, - expected: bool, - description: str, -) -> None: - """Test type_accepts_request with simple type combinations.""" - assert type_accepts_request(param_type, request_type) is expected, f"Failed: {description}" - - -@pytest.mark.parametrize( - "typevar_factory,expected,description", - [ - # TypeVar with bounds - (lambda: TypeVar("BoundRequest", bound=ListPromptsRequest), True, "TypeVar bound to request type"), - (lambda: TypeVar("BoundString", bound=str), False, "TypeVar bound to different type"), - # TypeVar with constraints - ( - lambda: TypeVar("ConstrainedRequest", ListPromptsRequest, ListToolsRequest), - True, - "TypeVar constrained to include request type", - ), - (lambda: TypeVar("ConstrainedPrimitives", str, int), False, "TypeVar constrained to primitives"), - # TypeVar without bounds or constraints - (lambda: TypeVar("T"), False, "unbounded TypeVar"), - ], -) -def test_type_accepts_request_typevar( - typevar_factory: Any, - expected: bool, - description: str, -) -> None: - """Test type_accepts_request with TypeVar types.""" - param_type = typevar_factory() - assert type_accepts_request(param_type, ListPromptsRequest) is expected, f"Failed: {description}" From f212f8f6040e3f5f7a4d856dc616adf0810fecc4 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:57:22 +0100 Subject: [PATCH 11/11] fix: remove deprecation code --- src/mcp/server/lowlevel/func_inspection.py | 42 ++-- src/mcp/server/lowlevel/server.py | 20 +- .../lowlevel/test_deprecation_warnings.py | 196 ------------------ tests/server/lowlevel/test_func_inspection.py | 113 ++-------- 4 files changed, 37 insertions(+), 334 deletions(-) delete mode 100644 tests/server/lowlevel/test_deprecation_warnings.py diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py index 8bd69ec22..f5a745db2 100644 --- a/src/mcp/server/lowlevel/func_inspection.py +++ b/src/mcp/server/lowlevel/func_inspection.py @@ -1,72 +1,54 @@ import inspect -import warnings from collections.abc import Callable from typing import Any, TypeVar, get_type_hints - -def issue_deprecation_warning(func: Callable[..., Any], request_type: type) -> None: - """ - Issue a deprecation warning for handlers that don't use the new request parameter style. - """ - func_name = getattr(func, "__name__", str(func)) - warnings.warn( - f"Handler '{func_name}' should accept a '{request_type.__name__}' parameter. " - "Support for handlers without this parameter will be removed in a future version.", - DeprecationWarning, - stacklevel=4, - ) - - T = TypeVar("T") R = TypeVar("R") -def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> tuple[Callable[[T], R], bool]: +def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callable[[T], R]: """ Create a wrapper function that knows how to call func with the request object. - Returns a tuple of (wrapper_func, should_deprecate): - - wrapper_func: A function that takes the request and calls func appropriately - - should_deprecate: True if a deprecation warning should be issued + Returns a wrapper function that takes the request and calls func appropriately. The wrapper handles three calling patterns: 1. Positional-only parameter typed as request_type (no default): func(req) 2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req}) - 3. No request parameter or parameter with default (deprecated): func() + 3. No request parameter or parameter with default: func() """ try: sig = inspect.signature(func) type_hints = get_type_hints(func) except (ValueError, TypeError, NameError): - # Can't inspect signature or resolve type hints, assume no request parameter (deprecated) - return lambda _: func(), True + return lambda _: func() # Check for positional-only parameter typed as request_type for param_name, param in sig.parameters.items(): if param.kind == inspect.Parameter.POSITIONAL_ONLY: param_type = type_hints.get(param_name) if param_type == request_type: - # Check if it has a default - if so, treat as old style (deprecated) + # Check if it has a default - if so, treat as old style if param.default is not inspect.Parameter.empty: - return lambda _: func(), True + return lambda _: func() # Found positional-only parameter with correct type and no default - return lambda req: func(req), False + return lambda req: func(req) # Check for any positional/keyword parameter typed as request_type for param_name, param in sig.parameters.items(): if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): param_type = type_hints.get(param_name) if param_type == request_type: - # Check if it has a default - if so, treat as old style (deprecated) + # Check if it has a default - if so, treat as old style if param.default is not inspect.Parameter.empty: - return lambda _: func(), True + return lambda _: func() # Found keyword parameter with correct type and no default # Need to capture param_name in closure properly def make_keyword_wrapper(name: str) -> Callable[[Any], Any]: return lambda req: func(**{name: req}) - return make_keyword_wrapper(param_name), False + return make_keyword_wrapper(param_name) - # No request parameter found - use old style (deprecated) - return lambda _: func(), True + # No request parameter found - use old style + return lambda _: func() diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index c3b978ecc..3448424bc 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -236,11 +236,7 @@ def decorator( ): logger.debug("Registering handler for PromptListRequest") - # Create wrapper that knows how to call func with the request - wrapper, _ = create_call_wrapper(func, types.ListPromptsRequest) - - # if should_deprecate: - # issue_deprecation_warning(func, types.ListPromptsRequest) + wrapper = create_call_wrapper(func, types.ListPromptsRequest) async def handler(req: types.ListPromptsRequest): result = await wrapper(req) @@ -278,12 +274,7 @@ def decorator( ): logger.debug("Registering handler for ListResourcesRequest") - # Create wrapper that knows how to call func with the request - wrapper, _ = create_call_wrapper(func, types.ListResourcesRequest) - - # TODO: Decide whether we want this sort of deprecation in a later PR - # if should_deprecate: - # issue_deprecation_warning(func, types.ListResourcesRequest) + wrapper = create_call_wrapper(func, types.ListResourcesRequest) async def handler(req: types.ListResourcesRequest): result = await wrapper(req) @@ -416,12 +407,7 @@ def decorator( ): logger.debug("Registering handler for ListToolsRequest") - # Create wrapper that knows how to call func with the request - wrapper, _ = create_call_wrapper(func, types.ListToolsRequest) - - # TODO: Decide whether we want this sort of deprecation in a later PR - # if should_deprecate: - # issue_deprecation_warning(func, types.ListToolsRequest) + wrapper = create_call_wrapper(func, types.ListToolsRequest) async def handler(req: types.ListToolsRequest): result = await wrapper(req) diff --git a/tests/server/lowlevel/test_deprecation_warnings.py b/tests/server/lowlevel/test_deprecation_warnings.py deleted file mode 100644 index f030b9459..000000000 --- a/tests/server/lowlevel/test_deprecation_warnings.py +++ /dev/null @@ -1,196 +0,0 @@ -# TODO: Decide whether we want deprecation warnings in another PR -# import warnings -# -# import pytest -# -# from mcp.server import Server -# from mcp.types import ( -# ListPromptsRequest, -# ListPromptsResult, -# ListResourcesRequest, -# ListResourcesResult, -# ListToolsRequest, -# ListToolsResult, -# PaginatedRequestParams, -# Prompt, -# Resource, -# ServerResult, -# Tool, -# ) -# -# -# @pytest.mark.anyio -# async def test_list_prompts_with_typed_request_no_warning() -> None: -# """Test that properly typed handlers don't trigger deprecation warnings.""" -# server = Server("test") -# -# with warnings.catch_warnings(record=True) as w: -# warnings.simplefilter("always") -# -# @server.list_prompts() -# async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: -# return ListPromptsResult(prompts=[]) -# -# # No deprecation warning should be issued -# assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 -# -# -# @pytest.mark.anyio -# async def test_list_prompts_without_params_triggers_warning() -> None: -# """Test that handlers without parameters trigger deprecation warnings.""" -# server = Server("test") -# -# with warnings.catch_warnings(record=True) as w: -# warnings.simplefilter("always") -# -# @server.list_prompts() -# async def handle_list_prompts() -> list[Prompt]: -# return [] -# -# # A deprecation warning should be issued -# deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] -# assert len(deprecation_warnings) == 1 -# assert "ListPromptsRequest" in str(deprecation_warnings[0].message) -# -# -# @pytest.mark.anyio -# async def test_list_prompts_with_untyped_param_triggers_warning() -> None: -# """Test that handlers with untyped parameters trigger deprecation warnings.""" -# server = Server("test") -# -# with warnings.catch_warnings(record=True) as w: -# warnings.simplefilter("always") -# -# @server.list_prompts() -# async def handle_list_prompts(request) -> ListPromptsResult: # type: ignore[no-untyped-def] -# return ListPromptsResult(prompts=[]) -# -# # A deprecation warning should be issued -# deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] -# assert len(deprecation_warnings) == 1 -# assert "ListPromptsRequest" in str(deprecation_warnings[0].message) -# -# -# @pytest.mark.anyio -# async def test_list_resources_with_typed_request_no_warning() -> None: -# """Test that properly typed resource handlers don't trigger warnings.""" -# server = Server("test") -# -# with warnings.catch_warnings(record=True) as w: -# warnings.simplefilter("always") -# -# @server.list_resources() -# async def handle_list_resources(request: ListResourcesRequest) -> ListResourcesResult: -# return ListResourcesResult(resources=[]) -# -# # No deprecation warning should be issued -# assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 -# -# -# @pytest.mark.anyio -# async def test_list_resources_without_params_triggers_warning() -> None: -# """Test that resource handlers without parameters trigger deprecation warnings.""" -# server = Server("test") -# -# with warnings.catch_warnings(record=True) as w: -# warnings.simplefilter("always") -# -# @server.list_resources() -# async def handle_list_resources() -> list[Resource]: -# return [] -# -# # A deprecation warning should be issued -# deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] -# assert len(deprecation_warnings) == 1 -# assert "ListResourcesRequest" in str(deprecation_warnings[0].message) -# -# -# @pytest.mark.anyio -# async def test_list_tools_with_typed_request_no_warning() -> None: -# """Test that properly typed tool handlers don't trigger warnings.""" -# server = Server("test") -# -# with warnings.catch_warnings(record=True) as w: -# warnings.simplefilter("always") -# -# @server.list_tools() -# async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: -# return ListToolsResult(tools=[]) -# -# # No deprecation warning should be issued -# assert len([warning for warning in w if issubclass(warning.category, DeprecationWarning)]) == 0 -# -# -# @pytest.mark.anyio -# async def test_list_tools_without_params_triggers_warning() -> None: -# """Test that tool handlers without parameters trigger deprecation warnings.""" -# server = Server("test") -# -# with warnings.catch_warnings(record=True) as w: -# warnings.simplefilter("always") -# -# @server.list_tools() -# async def handle_list_tools() -> list[Tool]: -# return [] -# -# # A deprecation warning should be issued -# deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] -# assert len(deprecation_warnings) == 1 -# assert "ListToolsRequest" in str(deprecation_warnings[0].message) -# -# -# @pytest.mark.anyio -# async def test_old_style_handler_still_works() -> None: -# """Test that old-style handlers still work (with deprecation warning).""" -# server = Server("test") -# -# with warnings.catch_warnings(record=True) as w: -# warnings.simplefilter("always") -# -# @server.list_prompts() -# async def handle_list_prompts() -> list[Prompt]: -# return [Prompt(name="test", description="Test prompt")] -# -# # Handler should be registered -# assert ListPromptsRequest in server.request_handlers -# -# # Deprecation warning should be issued -# deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] -# assert len(deprecation_warnings) == 1 -# -# # Handler should still work correctly -# handler = server.request_handlers[ListPromptsRequest] -# request = ListPromptsRequest(method="prompts/list", params=None) -# result = await handler(request) -# -# assert isinstance(result, ServerResult) -# assert isinstance(result.root, ListPromptsResult) -# assert len(result.root.prompts) == 1 -# assert result.root.prompts[0].name == "test" -# -# -# @pytest.mark.anyio -# async def test_new_style_handler_receives_pagination_params() -> None: -# """Test that new-style handlers receive pagination parameters correctly.""" -# server = Server("test") -# received_request: ListPromptsRequest | None = None -# -# @server.list_prompts() -# async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: -# nonlocal received_request -# received_request = request -# return ListPromptsResult(prompts=[], nextCursor="next-page") -# -# handler = server.request_handlers[ListPromptsRequest] -# -# # Test with cursor -# cursor_value = "test-cursor-123" -# request_with_cursor = ListPromptsRequest(method="prompts/list", -# params=PaginatedRequestParams(cursor=cursor_value)) -# result = await handler(request_with_cursor) -# -# assert received_request is not None -# assert received_request.params is not None -# assert received_request.params.cursor == cursor_value -# assert isinstance(result, ServerResult) -# assert result.root.nextCursor == "next-page" diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py index f2e7233c0..556fede4a 100644 --- a/tests/server/lowlevel/test_func_inspection.py +++ b/tests/server/lowlevel/test_func_inspection.py @@ -16,7 +16,7 @@ @pytest.mark.anyio async def test_no_params_returns_deprecated_wrapper() -> None: - """Test: def foo() - should call without request and mark as deprecated.""" + """Test: def foo() - should call without request.""" called_without_request = False async def handler() -> list[str]: @@ -24,9 +24,7 @@ async def handler() -> list[str]: called_without_request = True return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is True + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper should call handler without passing request request = ListPromptsRequest(method="prompts/list", params=None) @@ -37,7 +35,7 @@ async def handler() -> list[str]: @pytest.mark.anyio async def test_param_with_default_returns_deprecated_wrapper() -> None: - """Test: def foo(thing: int = 1) - should call without request and mark as deprecated.""" + """Test: def foo(thing: int = 1) - should call without request.""" called_without_request = False async def handler(thing: int = 1) -> list[str]: @@ -45,9 +43,7 @@ async def handler(thing: int = 1) -> list[str]: called_without_request = True return [f"test-{thing}"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is True + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper should call handler without passing request (uses default value) request = ListPromptsRequest(method="prompts/list", params=None) @@ -66,9 +62,7 @@ async def handler(req: ListPromptsRequest) -> list[str]: received_request = req return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is False + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper should pass request to handler request = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor="test-cursor")) @@ -93,9 +87,7 @@ async def handler(req: ListPromptsRequest, thing: int = 1) -> list[str]: received_thing = thing return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is False + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper should pass request to handler request = ListPromptsRequest(method="prompts/list", params=None) @@ -107,7 +99,7 @@ async def handler(req: ListPromptsRequest, thing: int = 1) -> list[str]: @pytest.mark.anyio async def test_optional_typed_request_with_default_none_is_deprecated() -> None: - """Test: def foo(thing: int = 1, req: ListPromptsRequest | None = None) - deprecated.""" + """Test: def foo(thing: int = 1, req: ListPromptsRequest | None = None) - old style.""" called_without_request = False async def handler(thing: int = 1, req: ListPromptsRequest | None = None) -> list[str]: @@ -115,10 +107,7 @@ async def handler(thing: int = 1, req: ListPromptsRequest | None = None) -> list called_without_request = True return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - # Because req has a default value (None), it's treated as deprecated - assert should_deprecate is True + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper should call handler without passing request request = ListPromptsRequest(method="prompts/list", params=None) @@ -129,7 +118,7 @@ async def handler(thing: int = 1, req: ListPromptsRequest | None = None) -> list @pytest.mark.anyio async def test_untyped_request_param_is_deprecated() -> None: - """Test: def foo(req) - should call without request and mark as deprecated.""" + """Test: def foo(req) - should call without request.""" called = False async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[reportMissingParameterType] @@ -137,9 +126,7 @@ async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[repor called = True return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) # pyright: ignore[reportUnknownArgumentType] - - assert should_deprecate is True + wrapper = create_call_wrapper(handler, ListPromptsRequest) # pyright: ignore[reportUnknownArgumentType] # Wrapper should call handler without passing request, which will fail because req is required request = ListPromptsRequest(method="prompts/list", params=None) @@ -150,14 +137,12 @@ async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[repor @pytest.mark.anyio async def test_any_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: Any) - should call without request and mark as deprecated.""" + """Test: def foo(req: Any) - should call without request.""" async def handler(req: Any) -> list[str]: return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is True + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper should call handler without passing request, which will fail because req is required request = ListPromptsRequest(method="prompts/list", params=None) @@ -168,14 +153,12 @@ async def handler(req: Any) -> list[str]: @pytest.mark.anyio async def test_generic_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: Generic[T]) - should call without request and mark as deprecated.""" + """Test: def foo(req: Generic[T]) - should call without request.""" async def handler(req: Generic[T]) -> list[str]: # pyright: ignore[reportGeneralTypeIssues] return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is True + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper should call handler without passing request, which will fail because req is required request = ListPromptsRequest(method="prompts/list", params=None) @@ -186,14 +169,12 @@ async def handler(req: Generic[T]) -> list[str]: # pyright: ignore[reportGenera @pytest.mark.anyio async def test_wrong_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: str) - should call without request and mark as deprecated.""" + """Test: def foo(req: str) - should call without request.""" async def handler(req: str) -> list[str]: return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is True + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper should call handler without passing request, which will fail because req is required request = ListPromptsRequest(method="prompts/list", params=None) @@ -212,10 +193,7 @@ async def handler(thing: int, req: ListPromptsRequest) -> list[str]: received_request = req return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - # Not marked as deprecated because it has the correct type hint - assert should_deprecate is False + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper will attempt to pass request, but it will fail at runtime # because 'thing' is required and has no default @@ -236,9 +214,7 @@ async def handler(req: ListPromptsRequest, /) -> list[str]: received_request = req return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is False + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper should pass request to handler request = ListPromptsRequest(method="prompts/list", params=None) @@ -247,18 +223,6 @@ async def handler(req: ListPromptsRequest, /) -> list[str]: assert received_request is request -def test_positional_only_param_with_default_is_deprecated() -> None: - """Test: def foo(req: ListPromptsRequest = None, /) - deprecated due to default value.""" - - async def handler(req: ListPromptsRequest = None, /) -> list[str]: # type: ignore[assignment] - return ["test"] - - _wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - # Has default value, so treated as deprecated - assert should_deprecate is True - - @pytest.mark.anyio async def test_keyword_only_param_with_correct_type() -> None: """Test: def foo(*, req: ListPromptsRequest) - should pass request through.""" @@ -269,9 +233,7 @@ async def handler(*, req: ListPromptsRequest) -> list[str]: received_request = req return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is False + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Wrapper should pass request to handler with keyword argument request = ListPromptsRequest(method="prompts/list", params=None) @@ -291,9 +253,7 @@ async def handler(req: ListResourcesRequest) -> list[str]: received_request = req return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListResourcesRequest) - - assert should_deprecate is False + wrapper = create_call_wrapper(handler, ListResourcesRequest) request = ListResourcesRequest(method="resources/list", params=None) await wrapper(request) @@ -308,9 +268,7 @@ async def handler2(req: ListToolsRequest) -> list[str]: received_request = req return ["test"] - wrapper2, should_deprecate2 = create_call_wrapper(handler2, ListToolsRequest) - - assert should_deprecate2 is False + wrapper2 = create_call_wrapper(handler2, ListToolsRequest) request2 = ListToolsRequest(method="tools/list", params=None) await wrapper2(request2) @@ -318,31 +276,6 @@ async def handler2(req: ListToolsRequest) -> list[str]: assert received_request is request2 -def test_lambda_without_annotations() -> None: - """Test that lambda functions work correctly.""" - # Lambda without type hints - should be deprecated - handler = lambda: ["test"] # noqa: E731 - - _wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is True - - -def test_function_without_type_hints_resolvable() -> None: - """Test functions where type hints can't be resolved.""" - - def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[reportMissingParameterType] - return ["test"] - - # Remove type hints to simulate resolution failure - handler.__annotations__ = {} - - _, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) # pyright: ignore[reportUnknownArgumentType] - - # Should default to deprecated when can't determine type - assert should_deprecate is True - - @pytest.mark.anyio async def test_mixed_params_with_typed_request() -> None: """Test: def foo(a: str, req: ListPromptsRequest, b: int = 5) - attempts to pass request.""" @@ -350,9 +283,7 @@ async def test_mixed_params_with_typed_request() -> None: async def handler(a: str, req: ListPromptsRequest, b: int = 5) -> list[str]: return ["test"] - wrapper, should_deprecate = create_call_wrapper(handler, ListPromptsRequest) - - assert should_deprecate is False + wrapper = create_call_wrapper(handler, ListPromptsRequest) # Will fail at runtime due to missing 'a' request = ListPromptsRequest(method="prompts/list", params=None)