2222from mcp .shared .message import ClientMessageMetadata , SessionMessage
2323from mcp .types import (
2424 ErrorData ,
25+ InitializeResult ,
2526 JSONRPCError ,
2627 JSONRPCMessage ,
2728 JSONRPCNotification ,
3940GetSessionIdCallback = Callable [[], str | None ]
4041
4142MCP_SESSION_ID = "mcp-session-id"
43+ MCP_PROTOCOL_VERSION = "mcp-protocol-version"
4244LAST_EVENT_ID = "last-event-id"
4345CONTENT_TYPE = "content-type"
4446ACCEPT = "Accept"
@@ -97,17 +99,20 @@ def __init__(
9799 )
98100 self .auth = auth
99101 self .session_id = None
102+ self .protocol_version = None
100103 self .request_headers = {
101104 ACCEPT : f"{ JSON } , { SSE } " ,
102105 CONTENT_TYPE : JSON ,
103106 ** self .headers ,
104107 }
105108
106- def _update_headers_with_session (self , base_headers : dict [str , str ]) -> dict [str , str ]:
107- """Update headers with session ID if available."""
109+ def _prepare_request_headers (self , base_headers : dict [str , str ]) -> dict [str , str ]:
110+ """Update headers with session ID and protocol version if available."""
108111 headers = base_headers .copy ()
109112 if self .session_id :
110113 headers [MCP_SESSION_ID ] = self .session_id
114+ if self .protocol_version :
115+ headers [MCP_PROTOCOL_VERSION ] = self .protocol_version
111116 return headers
112117
113118 def _is_initialization_request (self , message : JSONRPCMessage ) -> bool :
@@ -128,19 +133,39 @@ def _maybe_extract_session_id_from_response(
128133 self .session_id = new_session_id
129134 logger .info (f"Received session ID: { self .session_id } " )
130135
136+ def _maybe_extract_protocol_version_from_message (
137+ self ,
138+ message : JSONRPCMessage ,
139+ ) -> None :
140+ """Extract protocol version from initialization response message."""
141+ if isinstance (message .root , JSONRPCResponse ) and message .root .result :
142+ try :
143+ # Parse the result as InitializeResult for type safety
144+ init_result = InitializeResult .model_validate (message .root .result )
145+ self .protocol_version = str (init_result .protocolVersion )
146+ logger .info (f"Negotiated protocol version: { self .protocol_version } " )
147+ except Exception as exc :
148+ logger .warning (f"Failed to parse initialization response as InitializeResult: { exc } " )
149+ logger .warning (f"Raw result: { message .root .result } " )
150+
131151 async def _handle_sse_event (
132152 self ,
133153 sse : ServerSentEvent ,
134154 read_stream_writer : StreamWriter ,
135155 original_request_id : RequestId | None = None ,
136156 resumption_callback : Callable [[str ], Awaitable [None ]] | None = None ,
157+ is_initialization : bool = False ,
137158 ) -> bool :
138159 """Handle an SSE event, returning True if the response is complete."""
139160 if sse .event == "message" :
140161 try :
141162 message = JSONRPCMessage .model_validate_json (sse .data )
142163 logger .debug (f"SSE message: { message } " )
143164
165+ # Extract protocol version from initialization response
166+ if is_initialization :
167+ self ._maybe_extract_protocol_version_from_message (message )
168+
144169 # If this is a response and we have original_request_id, replace it
145170 if original_request_id is not None and isinstance (message .root , JSONRPCResponse | JSONRPCError ):
146171 message .root .id = original_request_id
@@ -174,7 +199,7 @@ async def handle_get_stream(
174199 if not self .session_id :
175200 return
176201
177- headers = self ._update_headers_with_session (self .request_headers )
202+ headers = self ._prepare_request_headers (self .request_headers )
178203
179204 async with aconnect_sse (
180205 client ,
@@ -194,7 +219,7 @@ async def handle_get_stream(
194219
195220 async def _handle_resumption_request (self , ctx : RequestContext ) -> None :
196221 """Handle a resumption request using GET with SSE."""
197- headers = self ._update_headers_with_session (ctx .headers )
222+ headers = self ._prepare_request_headers (ctx .headers )
198223 if ctx .metadata and ctx .metadata .resumption_token :
199224 headers [LAST_EVENT_ID ] = ctx .metadata .resumption_token
200225 else :
@@ -227,7 +252,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
227252
228253 async def _handle_post_request (self , ctx : RequestContext ) -> None :
229254 """Handle a POST request with response processing."""
230- headers = self ._update_headers_with_session (ctx .headers )
255+ headers = self ._prepare_request_headers (ctx .headers )
231256 message = ctx .session_message .message
232257 is_initialization = self ._is_initialization_request (message )
233258
@@ -256,9 +281,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
256281 content_type = response .headers .get (CONTENT_TYPE , "" ).lower ()
257282
258283 if content_type .startswith (JSON ):
259- await self ._handle_json_response (response , ctx .read_stream_writer )
284+ await self ._handle_json_response (response , ctx .read_stream_writer , is_initialization )
260285 elif content_type .startswith (SSE ):
261- await self ._handle_sse_response (response , ctx )
286+ await self ._handle_sse_response (response , ctx , is_initialization )
262287 else :
263288 await self ._handle_unexpected_content_type (
264289 content_type ,
@@ -269,18 +294,29 @@ async def _handle_json_response(
269294 self ,
270295 response : httpx .Response ,
271296 read_stream_writer : StreamWriter ,
297+ is_initialization : bool = False ,
272298 ) -> None :
273299 """Handle JSON response from the server."""
274300 try :
275301 content = await response .aread ()
276302 message = JSONRPCMessage .model_validate_json (content )
303+
304+ # Extract protocol version from initialization response
305+ if is_initialization :
306+ self ._maybe_extract_protocol_version_from_message (message )
307+
277308 session_message = SessionMessage (message )
278309 await read_stream_writer .send (session_message )
279310 except Exception as exc :
280311 logger .error (f"Error parsing JSON response: { exc } " )
281312 await read_stream_writer .send (exc )
282313
283- async def _handle_sse_response (self , response : httpx .Response , ctx : RequestContext ) -> None :
314+ async def _handle_sse_response (
315+ self ,
316+ response : httpx .Response ,
317+ ctx : RequestContext ,
318+ is_initialization : bool = False ,
319+ ) -> None :
284320 """Handle SSE response from the server."""
285321 try :
286322 event_source = EventSource (response )
@@ -289,6 +325,7 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte
289325 sse ,
290326 ctx .read_stream_writer ,
291327 resumption_callback = (ctx .metadata .on_resumption_token_update if ctx .metadata else None ),
328+ is_initialization = is_initialization ,
292329 )
293330 # If the SSE event indicates completion, like returning respose/error
294331 # break the loop
@@ -385,7 +422,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
385422 return
386423
387424 try :
388- headers = self ._update_headers_with_session (self .request_headers )
425+ headers = self ._prepare_request_headers (self .request_headers )
389426 response = await client .delete (self .url , headers = headers )
390427
391428 if response .status_code == 405 :
0 commit comments