From 09788c7316417989164d32eb730383b60e3a68da Mon Sep 17 00:00:00 2001 From: vvanglro Date: Mon, 19 May 2025 14:30:08 +0800 Subject: [PATCH] refactor(client): simplify context management and improve code structure - Replace explicit context manager usage with AsyncExitStack - Remove unnecessary comments and uncommented code- Improve error handling and resource cleanup- Update code formatting for better readability --- README.md | 16 ++++++---------- README_EN.md | 25 ++++++++++++------------- client.py | 17 +++++------------ main.py | 21 +++++++++------------ 4 files changed, 32 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index eb8d3d5..40658fc 100644 --- a/README.md +++ b/README.md @@ -482,7 +482,6 @@ import json import os from typing import Optional from contextlib import AsyncExitStack -import time from mcp import ClientSession from mcp.client.sse import sse_client @@ -501,11 +500,8 @@ class MCPClient: async def connect_to_sse_server(self, server_url: str): """Connect to an MCP server running with SSE transport""" # Store the context managers so they stay alive - self._streams_context = sse_client(url=server_url) - streams = await self._streams_context.__aenter__() - - self._session_context = ClientSession(*streams) - self.session: ClientSession = await self._session_context.__aenter__() + streams = await self.exit_stack.enter_async_context(sse_client(url=server_url)) + self.session = await self.exit_stack.enter_async_context(ClientSession(*streams)) # Initialize await self.session.initialize() @@ -519,10 +515,7 @@ class MCPClient: async def cleanup(self): """Properly clean up the session and streams""" - if self._session_context: - await self._session_context.__aexit__(None, None, None) - if self._streams_context: - await self._streams_context.__aexit__(None, None, None) + await self.exit_stack.aclose() async def process_query(self, query: str) -> str: """Process a query using OpenAI API and available tools""" @@ -600,6 +593,7 @@ class MCPClient: final_text.append(assistant_message.content) return "\n".join(final_text) + async def chat_loop(self): """Run an interactive chat loop""" @@ -619,6 +613,7 @@ class MCPClient: except Exception as e: print(f"\nError: {str(e)}") + async def main(): if len(sys.argv) < 2: print("Usage: uv run client.py ") @@ -631,6 +626,7 @@ async def main(): finally: await client.cleanup() + if __name__ == "__main__": import sys asyncio.run(main()) diff --git a/README_EN.md b/README_EN.md index 444d33c..7781256 100644 --- a/README_EN.md +++ b/README_EN.md @@ -462,7 +462,6 @@ import json import os from typing import Optional from contextlib import AsyncExitStack -import time from mcp import ClientSession from mcp.client.sse import sse_client @@ -481,11 +480,8 @@ class MCPClient: async def connect_to_sse_server(self, server_url: str): """Connect to an MCP server running with SSE transport""" # Store the context managers so they stay alive - self._streams_context = sse_client(url=server_url) - streams = await self._streams_context.__aenter__() - - self._session_context = ClientSession(*streams) - self.session: ClientSession = await self._session_context.__aenter__() + streams = await self.exit_stack.enter_async_context(sse_client(url=server_url)) + self.session = await self.exit_stack.enter_async_context(ClientSession(*streams)) # Initialize await self.session.initialize() @@ -499,10 +495,7 @@ class MCPClient: async def cleanup(self): """Properly clean up the session and streams""" - if self._session_context: - await self._session_context.__aexit__(None, None, None) - if self._streams_context: - await self._streams_context.__aexit__(None, None, None) + await self.exit_stack.aclose() async def process_query(self, query: str) -> str: """Process a query using OpenAI API and available tools""" @@ -542,11 +535,14 @@ class MCPClient: tool_name = tool_call.function.name tool_args = json.loads(tool_call.function.arguments) + + # Execute tool call result = await self.session.call_tool(tool_name, tool_args) tool_results.append({"call": tool_name, "result": result}) final_text.append(f"[Calling tool {tool_name} with args {tool_args}]") - + + # Continue conversation with tool results messages.extend([ { @@ -561,8 +557,8 @@ class MCPClient: } ]) - print(f"Tool {tool_name} returned: {result.content[0].text}") - print("messages", messages) + # print(f"Tool {tool_name} returned: {result.content[0].text}") + # print("messages", messages) # Get next response from OpenAI completion = await self.openai.chat.completions.create( model=os.getenv("OPENAI_MODEL"), @@ -580,6 +576,7 @@ class MCPClient: final_text.append(assistant_message.content) return "\n".join(final_text) + async def chat_loop(self): """Run an interactive chat loop""" @@ -599,6 +596,7 @@ class MCPClient: except Exception as e: print(f"\nError: {str(e)}") + async def main(): if len(sys.argv) < 2: print("Usage: uv run client.py ") @@ -611,6 +609,7 @@ async def main(): finally: await client.cleanup() + if __name__ == "__main__": import sys asyncio.run(main()) diff --git a/client.py b/client.py index 8ed0cb8..0bdceab 100644 --- a/client.py +++ b/client.py @@ -3,7 +3,6 @@ import os from typing import Optional from contextlib import AsyncExitStack -import time from mcp import ClientSession from mcp.client.sse import sse_client @@ -22,11 +21,8 @@ def __init__(self): async def connect_to_sse_server(self, server_url: str): """Connect to an MCP server running with SSE transport""" # Store the context managers so they stay alive - self._streams_context = sse_client(url=server_url) - streams = await self._streams_context.__aenter__() - - self._session_context = ClientSession(*streams) - self.session: ClientSession = await self._session_context.__aenter__() + streams = await self.exit_stack.enter_async_context(sse_client(url=server_url)) + self.session = await self.exit_stack.enter_async_context(ClientSession(*streams)) # Initialize await self.session.initialize() @@ -40,10 +36,7 @@ async def connect_to_sse_server(self, server_url: str): async def cleanup(self): """Properly clean up the session and streams""" - if self._session_context: - await self._session_context.__aexit__(None, None, None) - if self._streams_context: - await self._streams_context.__aexit__(None, None, None) + await self.exit_stack.aclose() async def process_query(self, query: str) -> str: """Process a query using OpenAI API and available tools""" @@ -105,8 +98,8 @@ async def process_query(self, query: str) -> str: } ]) - # print(f"Tool {tool_name} returned: {result.content[0].text}") - # print("messages", messages) + print(f"Tool {tool_name} returned: {result.content[0].text}") + print("messages", messages) # Get next response from OpenAI completion = await self.openai.chat.completions.create( model=os.getenv("OPENAI_MODEL"), diff --git a/main.py b/main.py index 44e99c6..e376eb8 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,8 @@ # main.py -from mcp.server.fastmcp import FastMCP from dotenv import load_dotenv -import httpx import json import os from bs4 import BeautifulSoup -from typing import Any import httpx from mcp.server.fastmcp import FastMCP from starlette.applications import Starlette @@ -92,13 +89,13 @@ async def get_docs(query: str, library: str): # Stdio协议 if __name__ == "__main__": mcp.run(transport="stdio") - + # # SSE协议 # def create_starlette_app(mcp_server: Server, *, debug: bool = False) -> Starlette: # """Create a Starlette application that can server the provied mcp server with SSE.""" # sse = SseServerTransport("/messages/") - +# # async def handle_sse(request: Request) -> None: # async with sse.connect_sse( # request.scope, @@ -110,7 +107,7 @@ async def get_docs(query: str, library: str): # write_stream, # mcp_server.create_initialization_options(), # ) - +# # return Starlette( # debug=debug, # routes=[ @@ -118,18 +115,18 @@ async def get_docs(query: str, library: str): # Mount("/messages/", app=sse.handle_post_message), # ], # ) - +# # if __name__ == "__main__": -# mcp_server = mcp._mcp_server - +# mcp_server = mcp._mcp_server +# # import argparse - +# # parser = argparse.ArgumentParser(description='Run MCP SSE-based server') # parser.add_argument('--host', default='0.0.0.0', help='Host to bind to') # parser.add_argument('--port', type=int, default=8020, help='Port to listen on') # args = parser.parse_args() - +# # # Bind SSE request handling to MCP server # starlette_app = create_starlette_app(mcp_server, debug=True) - +# # uvicorn.run(starlette_app, host=args.host, port=args.port)