diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 891eff0a6..a07a70270 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -165,6 +165,12 @@ async def initialize(self) -> types.InitializeResult: ) ), types.InitializeResult, + # TODO should set a request_read_timeout_seconds as per + # guidance from BaseSession.send_request not obvious + # what subsequent process should be, refer the following + # specification for more details + # https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation + cancellable=False, ) if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: @@ -287,6 +293,7 @@ async def call_tool( arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, progress_callback: ProgressFnT | None = None, + cancellable: bool = True, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" @@ -303,6 +310,7 @@ async def call_tool( types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, + cancellable=cancellable, ) if not result.isError: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index faad95aca..117c8f7e7 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -89,7 +89,7 @@ async def main(): from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.session import RequestId, RequestResponder logger = logging.getLogger(__name__) @@ -520,6 +520,20 @@ async def handler(req: types.ProgressNotification): return decorator + def cancel_notification(self): + def decorator( + func: Callable[[RequestId, str | None], Awaitable[None]], + ): + logger.debug("Registering handler for CancelledNotification") + + async def handler(req: types.CancelledNotification): + await func(req.params.requestId, req.params.reason) + + self.notification_handlers[types.CancelledNotification] = handler + return func + + return decorator + def completion(self): """Provides completions for prompts and resource templates""" diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 6536272d9..0adfe2138 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -16,7 +16,9 @@ from mcp.types import ( CONNECTION_CLOSED, INVALID_PARAMS, + REQUEST_CANCELLED, CancelledNotification, + CancelledNotificationParams, ClientNotification, ClientRequest, ClientResult, @@ -224,12 +226,25 @@ async def send_request( request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, progress_callback: ProgressFnT | None = None, + cancellable: bool = True, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the response contains an error. If a request read timeout is provided, it will take precedence over the session read timeout. + If cancellable is set to False then the request will wait + request_read_timeout_seconds to complete and ignore any attempt to + cancel via the anyio.CancelScope within which this method was called. + + If cancellable is set to True (default) if the anyio.CancelScope within + which this method was called is cancelled it will generate a + CancelationNotfication and send this to the server which should then abort + the task however the server is is not guaranteed to honour this request. + + For further information on the CancelNotification flow refer to + https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation + Do not use this method to emit notifications! Use send_notification() instead. """ @@ -267,20 +282,32 @@ async def send_request( elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - raise McpError( - ErrorData( - code=httpx.codes.REQUEST_TIMEOUT, - message=( - f"Timed out while waiting for response to " - f"{request.__class__.__name__}. Waited " - f"{timeout} seconds." - ), + with anyio.CancelScope(shield=not cancellable): + try: + with anyio.fail_after(timeout) as scope: + response_or_error = await response_stream_reader.receive() + + if scope.cancel_called: + notification = CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams(requestId=request_id, reason="cancelled"), + ) + await self._send_notification( # type: ignore + notification, request_id + ) + raise McpError(ErrorData(code=REQUEST_CANCELLED, message="Request cancelled")) + + except TimeoutError: + raise McpError( + ErrorData( + code=httpx.codes.REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to " + f"{request.__class__.__name__}. Waited " + f"{timeout} seconds." + ), + ) ) - ) if isinstance(response_or_error, JSONRPCError): raise McpError(response_or_error.error) diff --git a/src/mcp/types.py b/src/mcp/types.py index 4a9c2bf1a..51f412e4f 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -156,6 +156,7 @@ class JSONRPCResponse(BaseModel): METHOD_NOT_FOUND = -32601 INVALID_PARAMS = -32602 INTERNAL_ERROR = -32603 +REQUEST_CANCELLED = -32604 class ErrorData(BaseModel): diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 864e0d1b4..948d4391f 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,4 +1,5 @@ from collections.abc import AsyncGenerator +from datetime import timedelta import anyio import pytest @@ -8,14 +9,9 @@ from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import ( - create_client_server_memory_streams, create_connected_server_and_client_session, ) from mcp.types import ( - CancelledNotification, - CancelledNotificationParams, - ClientNotification, - ClientRequest, EmptyResult, ) @@ -49,11 +45,11 @@ async def test_in_flight_requests_cleared_after_completion( @pytest.mark.anyio async def test_request_cancellation(): """Test that requests can be cancelled while in-flight.""" - # The tool is already registered in the fixture ev_tool_called = anyio.Event() + ev_tool_cancelled = anyio.Event() ev_cancelled = anyio.Event() - request_id = None + ev_cancel_notified = anyio.Event() # Start the request in a separate task so we can cancel it def make_server() -> Server: @@ -62,14 +58,24 @@ def make_server() -> Server: # Register the tool handler @server.call_tool() async def handle_call_tool(name: str, arguments: dict | None) -> list: - nonlocal request_id, ev_tool_called + nonlocal ev_tool_called, ev_tool_cancelled if name == "slow_tool": - request_id = server.request_context.request_id ev_tool_called.set() - await anyio.sleep(10) # Long enough to ensure we can cancel - return [] + with anyio.CancelScope(): + try: + await anyio.sleep(10) # Long enough to ensure we can cancel + return [] + except anyio.get_cancelled_exc_class() as err: + ev_tool_cancelled.set() + raise err + raise ValueError(f"Unknown tool: {name}") + @server.cancel_notification() + async def handle_cancel(requestId: str | int, reason: str | None): + nonlocal ev_cancel_notified + ev_cancel_notified.set() + # Register the tool so it shows up in list_tools @server.list_tools() async def handle_list_tools() -> list[types.Tool]: @@ -83,18 +89,10 @@ async def handle_list_tools() -> list[types.Tool]: return server - async def make_request(client_session): + async def make_request(client_session: ClientSession): nonlocal ev_cancelled try: - await client_session.send_request( - ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams(name="slow_tool", arguments={}), - ) - ), - types.CallToolResult, - ) + await client_session.call_tool("slow_tool") pytest.fail("Request should have been cancelled") except McpError as e: # Expected - request was cancelled @@ -109,71 +107,85 @@ async def make_request(client_session): with anyio.fail_after(1): # Timeout after 1 second await ev_tool_called.wait() - # Send cancellation notification - assert request_id is not None - await client_session.send_notification( - ClientNotification( - CancelledNotification( - method="notifications/cancelled", - params=CancelledNotificationParams(requestId=request_id), - ) - ) - ) + # Cancel the task via task group + tg.cancel_scope.cancel() # Give cancellation time to process with anyio.fail_after(1): await ev_cancelled.wait() + # Check server cancel notification received + with anyio.fail_after(1): + await ev_cancel_notified.wait() + + # Give cancellation time to process on server + with anyio.fail_after(1): + await ev_tool_cancelled.wait() + @pytest.mark.anyio -async def test_connection_closed(): - """ - Test that pending requests are cancelled when the connection is closed remotely. - """ - - ev_closed = anyio.Event() - ev_response = anyio.Event() - - async with create_client_server_memory_streams() as ( - client_streams, - server_streams, - ): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def make_request(client_session): - """Send a request in a separate task""" - nonlocal ev_response - try: - # any request will do - await client_session.initialize() - pytest.fail("Request should have errored") - except McpError as e: - # Expected - request errored - assert "Connection closed" in str(e) - ev_response.set() - - async def mock_server(): - """Wait for a request, then close the connection""" - nonlocal ev_closed - # Wait for a request - await server_read.receive() - # Close the connection, as if the server exited - server_write.close() - server_read.close() - ev_closed.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession( - read_stream=client_read, - write_stream=client_write, - ) as client_session, - ): +async def test_request_cancellation_uncancellable(): + """Test that asserts a call with cancellable=False is not cancelled on + server when cancel scope on client is set.""" + + ev_tool_called = anyio.Event() + ev_tool_commplete = anyio.Event() + ev_cancelled = anyio.Event() + + # Start the request in a separate task so we can cancel it + def make_server() -> Server: + server = Server(name="TestSessionServer") + + # Register the tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict | None) -> list: + nonlocal ev_tool_called, ev_tool_commplete + if name == "slow_tool": + ev_tool_called.set() + with anyio.CancelScope(): + with anyio.fail_after(10): # Long enough to ensure we can cancel + await ev_cancelled.wait() + ev_tool_commplete.set() + return [] + + raise ValueError(f"Unknown tool: {name}") + + # Register the tool so it shows up in list_tools + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="slow_tool", + description="A slow tool that takes 10 seconds to complete", + inputSchema={}, + ) + ] + + return server + + async def make_request(client_session: ClientSession): + nonlocal ev_cancelled + try: + await client_session.call_tool( + "slow_tool", + cancellable=False, + read_timeout_seconds=timedelta(seconds=10), + ) + except McpError: + pytest.fail("Request should not have been cancelled") + + async with create_connected_server_and_client_session(make_server()) as client_session: + async with anyio.create_task_group() as tg: tg.start_soon(make_request, client_session) - tg.start_soon(mock_server) + # Wait for the request to be in-flight + with anyio.fail_after(1): # Timeout after 1 second + await ev_tool_called.wait() + + # Cancel the task via task group + tg.cancel_scope.cancel() + ev_cancelled.set() + + # Check server completed regardless with anyio.fail_after(1): - await ev_closed.wait() - with anyio.fail_after(1): - await ev_response.wait() + await ev_tool_commplete.wait()