@@ -213,3 +213,129 @@ async def mock_client():
213213
214214 assert received_initialized
215215 assert received_protocol_version == "2024-11-05"
216+
217+
218+ @pytest .mark .anyio
219+ async def test_ping_request_before_initialization ():
220+ """Test that ping requests are allowed before initialization is complete."""
221+ server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
222+ client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
223+
224+ ping_response_received = False
225+ ping_response_id = None
226+
227+ async def run_server ():
228+ async with ServerSession (
229+ client_to_server_receive ,
230+ server_to_client_send ,
231+ InitializationOptions (
232+ server_name = "mcp" ,
233+ server_version = "0.1.0" ,
234+ capabilities = ServerCapabilities (),
235+ ),
236+ ) as server_session :
237+ async for message in server_session .incoming_messages :
238+ if isinstance (message , Exception ):
239+ raise message
240+
241+ # We should receive a ping request before initialization
242+ if isinstance (message , RequestResponder ) and isinstance (message .request .root , types .PingRequest ):
243+ # Respond to the ping
244+ with message :
245+ await message .respond (types .ServerResult (types .EmptyResult ()))
246+ return
247+
248+ async def mock_client ():
249+ nonlocal ping_response_received , ping_response_id
250+
251+ # Send ping request before any initialization
252+ await client_to_server_send .send (
253+ SessionMessage (
254+ types .JSONRPCMessage (
255+ types .JSONRPCRequest (
256+ jsonrpc = "2.0" ,
257+ id = 42 ,
258+ method = "ping" ,
259+ )
260+ )
261+ )
262+ )
263+
264+ # Wait for the ping response
265+ ping_response_message = await server_to_client_receive .receive ()
266+ assert isinstance (ping_response_message .message .root , types .JSONRPCResponse )
267+
268+ ping_response_received = True
269+ ping_response_id = ping_response_message .message .root .id
270+
271+ async with (
272+ client_to_server_send ,
273+ client_to_server_receive ,
274+ server_to_client_send ,
275+ server_to_client_receive ,
276+ anyio .create_task_group () as tg ,
277+ ):
278+ tg .start_soon (run_server )
279+ tg .start_soon (mock_client )
280+
281+ assert ping_response_received
282+ assert ping_response_id == 42
283+
284+
285+ @pytest .mark .anyio
286+ async def test_other_requests_blocked_before_initialization ():
287+ """Test that non-ping requests are still blocked before initialization."""
288+ server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](1 )
289+ client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage | Exception ](1 )
290+
291+ error_response_received = False
292+ error_code = None
293+
294+ async def run_server ():
295+ async with ServerSession (
296+ client_to_server_receive ,
297+ server_to_client_send ,
298+ InitializationOptions (
299+ server_name = "mcp" ,
300+ server_version = "0.1.0" ,
301+ capabilities = ServerCapabilities (),
302+ ),
303+ ):
304+ # Server should handle the request and send an error response
305+ # No need to process incoming_messages since the error is handled automatically
306+ await anyio .sleep (0.1 ) # Give time for the request to be processed
307+
308+ async def mock_client ():
309+ nonlocal error_response_received , error_code
310+
311+ # Try to send a non-ping request before initialization
312+ await client_to_server_send .send (
313+ SessionMessage (
314+ types .JSONRPCMessage (
315+ types .JSONRPCRequest (
316+ jsonrpc = "2.0" ,
317+ id = 1 ,
318+ method = "prompts/list" ,
319+ )
320+ )
321+ )
322+ )
323+
324+ # Wait for the error response
325+ error_message = await server_to_client_receive .receive ()
326+ if isinstance (error_message .message .root , types .JSONRPCError ):
327+ error_response_received = True
328+ error_code = error_message .message .root .error .code
329+
330+ async with (
331+ client_to_server_send ,
332+ client_to_server_receive ,
333+ server_to_client_send ,
334+ server_to_client_receive ,
335+ anyio .create_task_group () as tg ,
336+ ):
337+ tg .start_soon (run_server )
338+ tg .start_soon (mock_client )
339+
340+ assert error_response_received
341+ assert error_code == types .INVALID_PARAMS
0 commit comments