1
1
"""Test that cancelled requests don't cause double responses."""
2
2
3
- import asyncio
4
- from unittest .mock import MagicMock
5
-
3
+ import anyio
6
4
import pytest
7
5
8
6
import mcp .types as types
9
7
from mcp .server .lowlevel .server import Server
10
- from mcp .types import PingRequest
11
-
12
-
13
- # Shared mock class
14
- class MockRequestResponder :
15
- def __init__ (self ):
16
- self .request_id = "test-123"
17
- self ._responded = False
18
- self .request_meta = {}
19
- self .message_metadata = None
20
-
21
- async def send (self , response ):
22
- if self ._responded :
23
- raise AssertionError (f"Request { self .request_id } already responded to" )
24
- self ._responded = True
25
-
26
- async def respond (self , response ):
27
- await self .send (response )
28
-
29
- def cancel (self ):
30
- """Simulate the cancel() method sending an error response."""
31
- asyncio .create_task (self .send (types .ErrorData (code = - 32800 , message = "Request cancelled" )))
8
+ from mcp .shared .exceptions import McpError
9
+ from mcp .shared .memory import create_connected_server_and_client_session
10
+ from mcp .types import (
11
+ CallToolRequest ,
12
+ CallToolRequestParams ,
13
+ CallToolResult ,
14
+ CancelledNotification ,
15
+ CancelledNotificationParams ,
16
+ ClientNotification ,
17
+ ClientRequest ,
18
+ Tool ,
19
+ )
32
20
33
21
34
22
@pytest .mark .anyio
35
23
async def test_cancelled_request_no_double_response ():
36
24
"""Verify server handles cancelled requests without double response."""
37
25
38
- # Create a server instance
26
+ # Create server with a slow tool
39
27
server = Server ("test-server" )
40
28
41
- # Track if multiple responses are attempted
42
- response_count = 0
43
-
44
- # Override the send method to track calls
45
- mock_message = MockRequestResponder ()
46
- original_send = mock_message .send
47
-
48
- async def tracked_send (response ):
49
- nonlocal response_count
50
- response_count += 1
51
- await original_send (response )
52
-
53
- mock_message .send = tracked_send
54
-
55
- # Create a slow handler that will be cancelled
56
- async def slow_handler (req ):
57
- await asyncio .sleep (10 )
58
- return types .ServerResult (types .EmptyResult ())
59
-
60
- # Use PingRequest as it's a valid request type
61
- server .request_handlers [types .PingRequest ] = slow_handler
62
-
63
- # Create mock message and session
64
- mock_req = PingRequest (method = "ping" )
65
- mock_session = MagicMock ()
66
- mock_context = None
67
-
68
- # Start the request
69
- handle_task = asyncio .create_task (
70
- server ._handle_request (mock_message , mock_req , mock_session , mock_context , raise_exceptions = False ) # type: ignore
71
- )
72
-
73
- # Give it time to start
74
- await asyncio .sleep (0.1 )
75
-
76
- # Simulate cancellation
77
- mock_message .cancel ()
78
- handle_task .cancel ()
79
-
80
- # Wait for cancellation to propagate
81
- try :
82
- await handle_task
83
- except asyncio .CancelledError :
84
- pass
85
-
86
- # Give time for any duplicate response attempts
87
- await asyncio .sleep (0.1 )
88
-
89
- # Should only have one response (from cancel())
90
- assert response_count == 1 , f"Expected 1 response, got { response_count } "
29
+ # Track when tool is called
30
+ ev_tool_called = anyio .Event ()
31
+ request_id = None
32
+
33
+ @server .list_tools ()
34
+ async def handle_list_tools () -> list [Tool ]:
35
+ return [
36
+ Tool (
37
+ name = "slow_tool" ,
38
+ description = "A slow tool for testing cancellation" ,
39
+ inputSchema = {},
40
+ )
41
+ ]
42
+
43
+ @server .call_tool ()
44
+ async def handle_call_tool (name : str , arguments : dict | None ) -> list :
45
+ nonlocal request_id
46
+ if name == "slow_tool" :
47
+ request_id = server .request_context .request_id
48
+ ev_tool_called .set ()
49
+ await anyio .sleep (10 ) # Long running operation
50
+ return [types .TextContent (type = "text" , text = "Tool called" )]
51
+ raise ValueError (f"Unknown tool: { name } " )
52
+
53
+ # Connect client to server
54
+ async with create_connected_server_and_client_session (server ) as client :
55
+ # Start the slow tool call in a separate task
56
+ async def make_request ():
57
+ try :
58
+ await client .send_request (
59
+ ClientRequest (
60
+ CallToolRequest (
61
+ method = "tools/call" ,
62
+ params = CallToolRequestParams (name = "slow_tool" , arguments = {}),
63
+ )
64
+ ),
65
+ CallToolResult ,
66
+ )
67
+ pytest .fail ("Request should have been cancelled" )
68
+ except McpError as e :
69
+ # Expected - request was cancelled
70
+ assert e .error .code == 0 # Request cancelled error code
71
+
72
+ # Start the request
73
+ request_task = anyio .create_task_group ()
74
+ async with request_task :
75
+ request_task .start_soon (make_request )
76
+
77
+ # Wait for tool to start executing
78
+ await ev_tool_called .wait ()
79
+
80
+ # Send cancellation notification
81
+ assert request_id is not None
82
+ await client .send_notification (
83
+ ClientNotification (
84
+ CancelledNotification (
85
+ method = "notifications/cancelled" ,
86
+ params = CancelledNotificationParams (
87
+ requestId = request_id ,
88
+ reason = "Test cancellation" ,
89
+ ),
90
+ )
91
+ )
92
+ )
93
+
94
+ # The request should be cancelled and raise McpError
91
95
92
96
93
97
@pytest .mark .anyio
@@ -96,43 +100,87 @@ async def test_server_remains_functional_after_cancel():
96
100
97
101
server = Server ("test-server" )
98
102
99
- # Add handlers
100
- async def slow_handler (req ):
101
- await asyncio .sleep (5 )
102
- return types .ServerResult (types .EmptyResult ())
103
-
104
- async def fast_handler (req ):
105
- return types .ServerResult (types .EmptyResult ())
106
-
107
- # Override ping handler for our test
108
- server .request_handlers [types .PingRequest ] = slow_handler
109
-
110
- # First request (will be cancelled)
111
- mock_message1 = MockRequestResponder ()
112
- mock_req1 = PingRequest (method = "ping" )
113
-
114
- handle_task = asyncio .create_task (
115
- server ._handle_request (mock_message1 , mock_req1 , MagicMock (), None , raise_exceptions = False ) # type: ignore
116
- )
117
-
118
- await asyncio .sleep (0.1 )
119
- mock_message1 .cancel ()
120
- handle_task .cancel ()
121
-
122
- try :
123
- await handle_task
124
- except asyncio .CancelledError :
125
- pass
126
-
127
- # Change handler to fast one
128
- server .request_handlers [types .PingRequest ] = fast_handler
129
-
130
- # Second request (should work normally)
131
- mock_message2 = MockRequestResponder ()
132
- mock_req2 = PingRequest (method = "ping" )
133
-
134
- # This should complete successfully
135
- await server ._handle_request (mock_message2 , mock_req2 , MagicMock (), None , raise_exceptions = False ) # type: ignore
136
-
137
- # Server handled the second request successfully
138
- assert mock_message2 ._responded
103
+ # Track tool calls
104
+ call_count = 0
105
+ ev_first_call = anyio .Event ()
106
+ first_request_id = None
107
+
108
+ @server .list_tools ()
109
+ async def handle_list_tools () -> list [Tool ]:
110
+ return [
111
+ Tool (
112
+ name = "test_tool" ,
113
+ description = "Tool for testing" ,
114
+ inputSchema = {},
115
+ )
116
+ ]
117
+
118
+ @server .call_tool ()
119
+ async def handle_call_tool (name : str , arguments : dict | None ) -> list :
120
+ nonlocal call_count , first_request_id
121
+ if name == "test_tool" :
122
+ call_count += 1
123
+ if call_count == 1 :
124
+ first_request_id = server .request_context .request_id
125
+ ev_first_call .set ()
126
+ await anyio .sleep (5 ) # First call is slow
127
+ return [types .TextContent (type = "text" , text = f"Call number: { call_count } " )]
128
+ raise ValueError (f"Unknown tool: { name } " )
129
+
130
+ async with create_connected_server_and_client_session (server ) as client :
131
+ # First request (will be cancelled)
132
+ async def first_request ():
133
+ try :
134
+ await client .send_request (
135
+ ClientRequest (
136
+ CallToolRequest (
137
+ method = "tools/call" ,
138
+ params = CallToolRequestParams (name = "test_tool" , arguments = {}),
139
+ )
140
+ ),
141
+ CallToolResult ,
142
+ )
143
+ pytest .fail ("First request should have been cancelled" )
144
+ except McpError :
145
+ pass # Expected
146
+
147
+ # Start first request
148
+ async with anyio .create_task_group () as tg :
149
+ tg .start_soon (first_request )
150
+
151
+ # Wait for it to start
152
+ await ev_first_call .wait ()
153
+
154
+ # Cancel it
155
+ assert first_request_id is not None
156
+ await client .send_notification (
157
+ ClientNotification (
158
+ CancelledNotification (
159
+ method = "notifications/cancelled" ,
160
+ params = CancelledNotificationParams (
161
+ requestId = first_request_id ,
162
+ reason = "Testing server recovery" ,
163
+ ),
164
+ )
165
+ )
166
+ )
167
+
168
+ # Second request (should work normally)
169
+ result = await client .send_request (
170
+ ClientRequest (
171
+ CallToolRequest (
172
+ method = "tools/call" ,
173
+ params = CallToolRequestParams (name = "test_tool" , arguments = {}),
174
+ )
175
+ ),
176
+ CallToolResult ,
177
+ )
178
+
179
+ # Verify second request completed successfully
180
+ assert len (result .content ) == 1
181
+ # Type narrowing for pyright
182
+ content = result .content [0 ]
183
+ assert content .type == "text"
184
+ assert isinstance (content , types .TextContent )
185
+ assert content .text == "Call number: 2"
186
+ assert call_count == 2
0 commit comments