Skip to content

Commit a5a112a

Browse files
committed
Fix cancel handling to use anyio idioms and improve tests
- Use anyio.get_cancelled_exc_class() instead of asyncio.CancelledError for proper backend compatibility (supports both asyncio and trio) - Rewrote cancel handling tests to use integration testing pattern with create_connected_server_and_client_session instead of mocking - Tests now properly send CancelledNotification messages through real client-server communication - Added proper type annotations and fixed type issues in tests Addresses review feedback about using anyio idioms and integration testing. Reported-by: @ihrpr
1 parent 81bd191 commit a5a112a

File tree

2 files changed

+165
-118
lines changed

2 files changed

+165
-118
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ async def main():
6767

6868
from __future__ import annotations as _annotations
6969

70-
import asyncio
7170
import contextvars
7271
import json
7372
import logging
@@ -648,7 +647,7 @@ async def _handle_request(
648647
response = await handler(req)
649648
except McpError as err:
650649
response = err.error
651-
except asyncio.CancelledError:
650+
except anyio.get_cancelled_exc_class():
652651
logger.info(
653652
"Request %s cancelled - duplicate response suppressed",
654653
message.request_id,

tests/server/test_cancel_handling.py

Lines changed: 164 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,97 @@
11
"""Test that cancelled requests don't cause double responses."""
22

3-
import asyncio
4-
from unittest.mock import MagicMock
5-
3+
import anyio
64
import pytest
75

86
import mcp.types as types
97
from mcp.server.lowlevel.server import Server
10-
from mcp.types import PingRequest
11-
12-
13-
# Shared mock class
14-
class MockRequestResponder:
15-
def __init__(self):
16-
self.request_id = "test-123"
17-
self._responded = False
18-
self.request_meta = {}
19-
self.message_metadata = None
20-
21-
async def send(self, response):
22-
if self._responded:
23-
raise AssertionError(f"Request {self.request_id} already responded to")
24-
self._responded = True
25-
26-
async def respond(self, response):
27-
await self.send(response)
28-
29-
def cancel(self):
30-
"""Simulate the cancel() method sending an error response."""
31-
asyncio.create_task(self.send(types.ErrorData(code=-32800, message="Request cancelled")))
8+
from mcp.shared.exceptions import McpError
9+
from mcp.shared.memory import create_connected_server_and_client_session
10+
from mcp.types import (
11+
CallToolRequest,
12+
CallToolRequestParams,
13+
CallToolResult,
14+
CancelledNotification,
15+
CancelledNotificationParams,
16+
ClientNotification,
17+
ClientRequest,
18+
Tool,
19+
)
3220

3321

3422
@pytest.mark.anyio
3523
async def test_cancelled_request_no_double_response():
3624
"""Verify server handles cancelled requests without double response."""
3725

38-
# Create a server instance
26+
# Create server with a slow tool
3927
server = Server("test-server")
4028

41-
# Track if multiple responses are attempted
42-
response_count = 0
43-
44-
# Override the send method to track calls
45-
mock_message = MockRequestResponder()
46-
original_send = mock_message.send
47-
48-
async def tracked_send(response):
49-
nonlocal response_count
50-
response_count += 1
51-
await original_send(response)
52-
53-
mock_message.send = tracked_send
54-
55-
# Create a slow handler that will be cancelled
56-
async def slow_handler(req):
57-
await asyncio.sleep(10)
58-
return types.ServerResult(types.EmptyResult())
59-
60-
# Use PingRequest as it's a valid request type
61-
server.request_handlers[types.PingRequest] = slow_handler
62-
63-
# Create mock message and session
64-
mock_req = PingRequest(method="ping")
65-
mock_session = MagicMock()
66-
mock_context = None
67-
68-
# Start the request
69-
handle_task = asyncio.create_task(
70-
server._handle_request(mock_message, mock_req, mock_session, mock_context, raise_exceptions=False) # type: ignore
71-
)
72-
73-
# Give it time to start
74-
await asyncio.sleep(0.1)
75-
76-
# Simulate cancellation
77-
mock_message.cancel()
78-
handle_task.cancel()
79-
80-
# Wait for cancellation to propagate
81-
try:
82-
await handle_task
83-
except asyncio.CancelledError:
84-
pass
85-
86-
# Give time for any duplicate response attempts
87-
await asyncio.sleep(0.1)
88-
89-
# Should only have one response (from cancel())
90-
assert response_count == 1, f"Expected 1 response, got {response_count}"
29+
# Track when tool is called
30+
ev_tool_called = anyio.Event()
31+
request_id = None
32+
33+
@server.list_tools()
34+
async def handle_list_tools() -> list[Tool]:
35+
return [
36+
Tool(
37+
name="slow_tool",
38+
description="A slow tool for testing cancellation",
39+
inputSchema={},
40+
)
41+
]
42+
43+
@server.call_tool()
44+
async def handle_call_tool(name: str, arguments: dict | None) -> list:
45+
nonlocal request_id
46+
if name == "slow_tool":
47+
request_id = server.request_context.request_id
48+
ev_tool_called.set()
49+
await anyio.sleep(10) # Long running operation
50+
return [types.TextContent(type="text", text="Tool called")]
51+
raise ValueError(f"Unknown tool: {name}")
52+
53+
# Connect client to server
54+
async with create_connected_server_and_client_session(server) as client:
55+
# Start the slow tool call in a separate task
56+
async def make_request():
57+
try:
58+
await client.send_request(
59+
ClientRequest(
60+
CallToolRequest(
61+
method="tools/call",
62+
params=CallToolRequestParams(name="slow_tool", arguments={}),
63+
)
64+
),
65+
CallToolResult,
66+
)
67+
pytest.fail("Request should have been cancelled")
68+
except McpError as e:
69+
# Expected - request was cancelled
70+
assert e.error.code == 0 # Request cancelled error code
71+
72+
# Start the request
73+
request_task = anyio.create_task_group()
74+
async with request_task:
75+
request_task.start_soon(make_request)
76+
77+
# Wait for tool to start executing
78+
await ev_tool_called.wait()
79+
80+
# Send cancellation notification
81+
assert request_id is not None
82+
await client.send_notification(
83+
ClientNotification(
84+
CancelledNotification(
85+
method="notifications/cancelled",
86+
params=CancelledNotificationParams(
87+
requestId=request_id,
88+
reason="Test cancellation",
89+
),
90+
)
91+
)
92+
)
93+
94+
# The request should be cancelled and raise McpError
9195

9296

9397
@pytest.mark.anyio
@@ -96,43 +100,87 @@ async def test_server_remains_functional_after_cancel():
96100

97101
server = Server("test-server")
98102

99-
# Add handlers
100-
async def slow_handler(req):
101-
await asyncio.sleep(5)
102-
return types.ServerResult(types.EmptyResult())
103-
104-
async def fast_handler(req):
105-
return types.ServerResult(types.EmptyResult())
106-
107-
# Override ping handler for our test
108-
server.request_handlers[types.PingRequest] = slow_handler
109-
110-
# First request (will be cancelled)
111-
mock_message1 = MockRequestResponder()
112-
mock_req1 = PingRequest(method="ping")
113-
114-
handle_task = asyncio.create_task(
115-
server._handle_request(mock_message1, mock_req1, MagicMock(), None, raise_exceptions=False) # type: ignore
116-
)
117-
118-
await asyncio.sleep(0.1)
119-
mock_message1.cancel()
120-
handle_task.cancel()
121-
122-
try:
123-
await handle_task
124-
except asyncio.CancelledError:
125-
pass
126-
127-
# Change handler to fast one
128-
server.request_handlers[types.PingRequest] = fast_handler
129-
130-
# Second request (should work normally)
131-
mock_message2 = MockRequestResponder()
132-
mock_req2 = PingRequest(method="ping")
133-
134-
# This should complete successfully
135-
await server._handle_request(mock_message2, mock_req2, MagicMock(), None, raise_exceptions=False) # type: ignore
136-
137-
# Server handled the second request successfully
138-
assert mock_message2._responded
103+
# Track tool calls
104+
call_count = 0
105+
ev_first_call = anyio.Event()
106+
first_request_id = None
107+
108+
@server.list_tools()
109+
async def handle_list_tools() -> list[Tool]:
110+
return [
111+
Tool(
112+
name="test_tool",
113+
description="Tool for testing",
114+
inputSchema={},
115+
)
116+
]
117+
118+
@server.call_tool()
119+
async def handle_call_tool(name: str, arguments: dict | None) -> list:
120+
nonlocal call_count, first_request_id
121+
if name == "test_tool":
122+
call_count += 1
123+
if call_count == 1:
124+
first_request_id = server.request_context.request_id
125+
ev_first_call.set()
126+
await anyio.sleep(5) # First call is slow
127+
return [types.TextContent(type="text", text=f"Call number: {call_count}")]
128+
raise ValueError(f"Unknown tool: {name}")
129+
130+
async with create_connected_server_and_client_session(server) as client:
131+
# First request (will be cancelled)
132+
async def first_request():
133+
try:
134+
await client.send_request(
135+
ClientRequest(
136+
CallToolRequest(
137+
method="tools/call",
138+
params=CallToolRequestParams(name="test_tool", arguments={}),
139+
)
140+
),
141+
CallToolResult,
142+
)
143+
pytest.fail("First request should have been cancelled")
144+
except McpError:
145+
pass # Expected
146+
147+
# Start first request
148+
async with anyio.create_task_group() as tg:
149+
tg.start_soon(first_request)
150+
151+
# Wait for it to start
152+
await ev_first_call.wait()
153+
154+
# Cancel it
155+
assert first_request_id is not None
156+
await client.send_notification(
157+
ClientNotification(
158+
CancelledNotification(
159+
method="notifications/cancelled",
160+
params=CancelledNotificationParams(
161+
requestId=first_request_id,
162+
reason="Testing server recovery",
163+
),
164+
)
165+
)
166+
)
167+
168+
# Second request (should work normally)
169+
result = await client.send_request(
170+
ClientRequest(
171+
CallToolRequest(
172+
method="tools/call",
173+
params=CallToolRequestParams(name="test_tool", arguments={}),
174+
)
175+
),
176+
CallToolResult,
177+
)
178+
179+
# Verify second request completed successfully
180+
assert len(result.content) == 1
181+
# Type narrowing for pyright
182+
content = result.content[0]
183+
assert content.type == "text"
184+
assert isinstance(content, types.TextContent)
185+
assert content.text == "Call number: 2"
186+
assert call_count == 2

0 commit comments

Comments
 (0)