55
66import mcp
77from mcp import types
8- from mcp .client .session_group import ClientSessionGroup
8+ from mcp .client .session_group import (
9+ ClientSessionGroup ,
10+ SseServerParameters ,
11+ StreamableHttpParameters ,
12+ )
913from mcp .client .stdio import StdioServerParameters
1014from mcp .shared .exceptions import McpError
1115
@@ -19,12 +23,6 @@ def mock_exit_stack():
1923 return mock .MagicMock (spec = contextlib .AsyncExitStack )
2024
2125
22- @pytest .fixture
23- def mock_server_params (): # No mocker needed here
24- """Fixture for mocked StdioServerParameters."""
25- return mock .Mock (spec = StdioServerParameters )
26-
27-
2826@pytest .mark .anyio
2927class TestClientSessionGroup :
3028 def test_init (self ):
@@ -79,7 +77,7 @@ async def test_call_tool(self):
7977 {"name" : "value1" , "args" : {}},
8078 )
8179
82- async def test_connect_to_server (self , mock_exit_stack , mock_server_params ):
80+ async def test_connect_to_server (self , mock_exit_stack ):
8381 """Test connecting to a server and aggregating components."""
8482 # --- Mock Dependencies ---
8583 mock_server_info = mock .Mock (spec = types .Implementation )
@@ -102,7 +100,7 @@ async def test_connect_to_server(self, mock_exit_stack, mock_server_params):
102100 with mock .patch .object (
103101 group , "_establish_session" , return_value = (mock_server_info , mock_session )
104102 ):
105- await group .connect_to_server (mock_server_params )
103+ await group .connect_to_server (StdioServerParameters ( command = "test" ) )
106104
107105 # --- Assertions ---
108106 assert mock_session in group ._sessions
@@ -120,9 +118,7 @@ async def test_connect_to_server(self, mock_exit_stack, mock_server_params):
120118 mock_session .list_resources .assert_awaited_once ()
121119 mock_session .list_prompts .assert_awaited_once ()
122120
123- async def test_connect_to_server_with_name_hook (
124- self , mock_exit_stack , mock_server_params
125- ):
121+ async def test_connect_to_server_with_name_hook (self , mock_exit_stack ):
126122 """Test connecting with a component name hook."""
127123 # --- Mock Dependencies ---
128124 mock_server_info = mock .Mock (spec = types .Implementation )
@@ -145,7 +141,7 @@ def name_hook(name: str, server_info: types.Implementation) -> str:
145141 with mock .patch .object (
146142 group , "_establish_session" , return_value = (mock_server_info , mock_session )
147143 ):
148- await group .connect_to_server (mock_server_params )
144+ await group .connect_to_server (StdioServerParameters ( command = "test" ) )
149145
150146 # --- Assertions ---
151147 assert mock_session in group ._sessions
@@ -218,9 +214,7 @@ def test_disconnect_from_server(self): # No mock arguments needed
218214 assert "res1" not in group ._resources
219215 assert "prm1" not in group ._prompts
220216
221- async def test_connect_to_server_duplicate_tool_raises_error (
222- self , mock_exit_stack , mock_server_params
223- ):
217+ async def test_connect_to_server_duplicate_tool_raises_error (self , mock_exit_stack ):
224218 """Test McpError raised when connecting a server with a dup name."""
225219 # --- Setup Pre-existing State ---
226220 group = ClientSessionGroup (exit_stack = mock_exit_stack )
@@ -255,7 +249,7 @@ async def test_connect_to_server_duplicate_tool_raises_error(
255249 "_establish_session" ,
256250 return_value = (mock_server_info_new , mock_session_new ),
257251 ):
258- await group .connect_to_server (mock_server_params )
252+ await group .connect_to_server (StdioServerParameters ( command = "test" ) )
259253
260254 # Assert details about the raised error
261255 assert excinfo .value .error .code == types .INVALID_PARAMS
@@ -269,9 +263,127 @@ async def test_connect_to_server_duplicate_tool_raises_error(
269263 ) # Ensure it's the original mock
270264
271265 # No patching needed here
272- def test_disconnect_non_existent_server (self ): # No mock arguments needed
266+ def test_disconnect_non_existent_server (self ):
273267 """Test disconnecting a server that isn't connected."""
274268 session = mock .Mock (spec = mcp .ClientSession )
275269 group = ClientSessionGroup ()
276270 with pytest .raises (McpError ):
277271 group .disconnect_from_server (session )
272+
273+ @pytest .mark .parametrize (
274+ "server_params_instance, client_type_name, patch_target_for_client_func" ,
275+ [
276+ (
277+ StdioServerParameters (command = "test_stdio_cmd" ),
278+ "stdio" ,
279+ "mcp.client.session_group.mcp.stdio_client" ,
280+ ),
281+ (
282+ SseServerParameters (url = "http://test.com/sse" , timeout = 10 ),
283+ "sse" ,
284+ "mcp.client.session_group.sse_client" ,
285+ ), # url, headers, timeout, sse_read_timeout
286+ (
287+ StreamableHttpParameters (
288+ url = "http://test.com/stream" , terminate_on_close = False
289+ ),
290+ "streamablehttp" ,
291+ "mcp.client.session_group.streamablehttp_client" ,
292+ ), # url, headers, timeout, sse_read_timeout, terminate_on_close
293+ ],
294+ )
295+ async def test_establish_session_parameterized (
296+ self ,
297+ server_params_instance ,
298+ client_type_name , # Just for clarity or conditional logic if needed
299+ patch_target_for_client_func ,
300+ ):
301+ with mock .patch (
302+ "mcp.client.session_group.mcp.ClientSession"
303+ ) as mock_ClientSession_class :
304+ with mock .patch (patch_target_for_client_func ) as mock_specific_client_func :
305+ mock_client_cm_instance = mock .AsyncMock (
306+ name = f"{ client_type_name } ClientCM"
307+ )
308+ mock_read_stream = mock .AsyncMock (name = f"{ client_type_name } Read" )
309+ mock_write_stream = mock .AsyncMock (name = f"{ client_type_name } Write" )
310+
311+ # streamablehttp_client's __aenter__ returns three values
312+ if client_type_name == "streamablehttp" :
313+ mock_extra_stream_val = mock .AsyncMock (name = "StreamableExtra" )
314+ mock_client_cm_instance .__aenter__ .return_value = (
315+ mock_read_stream ,
316+ mock_write_stream ,
317+ mock_extra_stream_val ,
318+ )
319+ else :
320+ mock_client_cm_instance .__aenter__ .return_value = (
321+ mock_read_stream ,
322+ mock_write_stream ,
323+ )
324+
325+ mock_client_cm_instance .__aexit__ = mock .AsyncMock (return_value = None )
326+ mock_specific_client_func .return_value = mock_client_cm_instance
327+
328+ # --- Mock mcp.ClientSession (class) ---
329+ # mock_ClientSession_class is already provided by the outer patch
330+ mock_raw_session_cm = mock .AsyncMock (name = "RawSessionCM" )
331+ mock_ClientSession_class .return_value = mock_raw_session_cm
332+
333+ mock_entered_session = mock .AsyncMock (name = "EnteredSessionInstance" )
334+ mock_raw_session_cm .__aenter__ .return_value = mock_entered_session
335+ mock_raw_session_cm .__aexit__ = mock .AsyncMock (return_value = None )
336+
337+ # Mock session.initialize()
338+ mock_initialize_result = mock .AsyncMock (name = "InitializeResult" )
339+ mock_initialize_result .serverInfo = types .Implementation (
340+ name = "foo" , version = "1"
341+ )
342+ mock_entered_session .initialize .return_value = mock_initialize_result
343+
344+ # --- Test Execution ---
345+ group = ClientSessionGroup ()
346+ returned_server_info = None
347+ returned_session = None
348+
349+ async with contextlib .AsyncExitStack () as stack :
350+ group ._exit_stack = stack
351+ (
352+ returned_server_info ,
353+ returned_session ,
354+ ) = await group ._establish_session (server_params_instance )
355+
356+ # --- Assertions ---
357+ # 1. Assert the correct specific client function was called
358+ if client_type_name == "stdio" :
359+ mock_specific_client_func .assert_called_once_with (
360+ server_params_instance
361+ )
362+ elif client_type_name == "sse" :
363+ mock_specific_client_func .assert_called_once_with (
364+ url = server_params_instance .url ,
365+ headers = server_params_instance .headers ,
366+ timeout = server_params_instance .timeout ,
367+ sse_read_timeout = server_params_instance .sse_read_timeout ,
368+ )
369+ elif client_type_name == "streamablehttp" :
370+ mock_specific_client_func .assert_called_once_with (
371+ url = server_params_instance .url ,
372+ headers = server_params_instance .headers ,
373+ timeout = server_params_instance .timeout ,
374+ sse_read_timeout = server_params_instance .sse_read_timeout ,
375+ terminate_on_close = server_params_instance .terminate_on_close ,
376+ )
377+
378+ mock_client_cm_instance .__aenter__ .assert_awaited_once ()
379+
380+ # 2. Assert ClientSession was called correctly
381+ mock_ClientSession_class .assert_called_once_with (
382+ mock_read_stream , mock_write_stream
383+ )
384+ mock_raw_session_cm .__aenter__ .assert_awaited_once ()
385+ mock_entered_session .initialize .assert_awaited_once ()
386+
387+ # 3. Assert returned values
388+ assert returned_server_info is mock_initialize_result .serverInfo
389+ assert returned_session is mock_entered_session
0 commit comments