@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
129129 _read_stream_writer : MemoryObjectSendStream [SessionMessage | Exception ] | None = (
130130 None
131131 )
132+ _read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ] | None = None
133+ _write_stream : MemoryObjectSendStream [SessionMessage ] | None = None
132134 _write_stream_reader : MemoryObjectReceiveStream [SessionMessage ] | None = None
133135
134136 def __init__ (
@@ -163,7 +165,11 @@ def __init__(
163165 self .is_json_response_enabled = is_json_response_enabled
164166 self ._event_store = event_store
165167 self ._request_streams : dict [
166- RequestId , MemoryObjectSendStream [EventMessage ]
168+ RequestId ,
169+ tuple [
170+ MemoryObjectSendStream [EventMessage ],
171+ MemoryObjectReceiveStream [EventMessage ],
172+ ],
167173 ] = {}
168174 self ._terminated = False
169175
@@ -239,6 +245,19 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
239245
240246 return event_data
241247
248+ async def _clean_up_memory_streams (self , request_id : RequestId ) -> None :
249+ """Clean up memory streams for a given request ID."""
250+ if request_id in self ._request_streams :
251+ try :
252+ # Close the request stream
253+ await self ._request_streams [request_id ][0 ].aclose ()
254+ await self ._request_streams [request_id ][1 ].aclose ()
255+ except Exception as e :
256+ logger .debug (f"Error closing memory streams: { e } " )
257+ finally :
258+ # Remove the request stream from the mapping
259+ self ._request_streams .pop (request_id , None )
260+
242261 async def handle_request (self , scope : Scope , receive : Receive , send : Send ) -> None :
243262 """Application entry point that handles all HTTP requests"""
244263 request = Request (scope , receive )
@@ -386,13 +405,11 @@ async def _handle_post_request(
386405
387406 # Extract the request ID outside the try block for proper scope
388407 request_id = str (message .root .id )
389- # Create promise stream for getting response
390- request_stream_writer , request_stream_reader = (
391- anyio .create_memory_object_stream [EventMessage ](0 )
392- )
393-
394408 # Register this stream for the request ID
395- self ._request_streams [request_id ] = request_stream_writer
409+ self ._request_streams [request_id ] = anyio .create_memory_object_stream [
410+ EventMessage
411+ ](0 )
412+ request_stream_reader = self ._request_streams [request_id ][1 ]
396413
397414 if self .is_json_response_enabled :
398415 # Process the message
@@ -441,11 +458,7 @@ async def _handle_post_request(
441458 )
442459 await response (scope , receive , send )
443460 finally :
444- # Clean up the request stream
445- if request_id in self ._request_streams :
446- self ._request_streams .pop (request_id , None )
447- await request_stream_reader .aclose ()
448- await request_stream_writer .aclose ()
461+ await self ._clean_up_memory_streams (request_id )
449462 else :
450463 # Create SSE stream
451464 sse_stream_writer , sse_stream_reader = (
@@ -467,16 +480,12 @@ async def sse_writer():
467480 event_message .message .root ,
468481 JSONRPCResponse | JSONRPCError ,
469482 ):
470- if request_id :
471- self ._request_streams .pop (request_id , None )
472483 break
473484 except Exception as e :
474485 logger .exception (f"Error in SSE writer: { e } " )
475486 finally :
476487 logger .debug ("Closing SSE writer" )
477- # Clean up the request-specific streams
478- if request_id and request_id in self ._request_streams :
479- self ._request_streams .pop (request_id , None )
488+ await self ._clean_up_memory_streams (request_id )
480489
481490 # Create and start EventSourceResponse
482491 # SSE stream mode (original behavior)
@@ -507,9 +516,9 @@ async def sse_writer():
507516 await writer .send (session_message )
508517 except Exception :
509518 logger .exception ("SSE response error" )
510- # Clean up the request stream if something goes wrong
511- if request_id and request_id in self . _request_streams :
512- self ._request_streams . pop (request_id , None )
519+ await sse_stream_writer . aclose ()
520+ await sse_stream_reader . aclose ()
521+ await self ._clean_up_memory_streams (request_id )
513522
514523 except Exception as err :
515524 logger .exception ("Error handling POST request" )
@@ -581,12 +590,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
581590 async def standalone_sse_writer ():
582591 try :
583592 # Create a standalone message stream for server-initiated messages
584- standalone_stream_writer , standalone_stream_reader = (
593+
594+ self ._request_streams [GET_STREAM_KEY ] = (
585595 anyio .create_memory_object_stream [EventMessage ](0 )
586596 )
587-
588- # Register this stream using the special key
589- self ._request_streams [GET_STREAM_KEY ] = standalone_stream_writer
597+ standalone_stream_reader = self ._request_streams [GET_STREAM_KEY ][1 ]
590598
591599 async with sse_stream_writer , standalone_stream_reader :
592600 # Process messages from the standalone stream
@@ -603,8 +611,7 @@ async def standalone_sse_writer():
603611 logger .exception (f"Error in standalone SSE writer: { e } " )
604612 finally :
605613 logger .debug ("Closing standalone SSE writer" )
606- # Remove the stream from request_streams
607- self ._request_streams .pop (GET_STREAM_KEY , None )
614+ await self ._clean_up_memory_streams (GET_STREAM_KEY )
608615
609616 # Create and start EventSourceResponse
610617 response = EventSourceResponse (
@@ -618,8 +625,9 @@ async def standalone_sse_writer():
618625 await response (request .scope , request .receive , send )
619626 except Exception as e :
620627 logger .exception (f"Error in standalone SSE response: { e } " )
621- # Clean up the request stream
622- self ._request_streams .pop (GET_STREAM_KEY , None )
628+ await sse_stream_writer .aclose ()
629+ await sse_stream_reader .aclose ()
630+ await self ._clean_up_memory_streams (GET_STREAM_KEY )
623631
624632 async def _handle_delete_request (self , request : Request , send : Send ) -> None :
625633 """Handle DELETE requests for explicit session termination."""
@@ -636,15 +644,15 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
636644 if not await self ._validate_session (request , send ):
637645 return
638646
639- self ._terminate_session ()
647+ await self ._terminate_session ()
640648
641649 response = self ._create_json_response (
642650 None ,
643651 HTTPStatus .OK ,
644652 )
645653 await response (request .scope , request .receive , send )
646654
647- def _terminate_session (self ) -> None :
655+ async def _terminate_session (self ) -> None :
648656 """Terminate the current session, closing all streams.
649657
650658 Once terminated, all requests with this session ID will receive 404 Not Found.
@@ -656,19 +664,26 @@ def _terminate_session(self) -> None:
656664 # We need a copy of the keys to avoid modification during iteration
657665 request_stream_keys = list (self ._request_streams .keys ())
658666
659- # Close all request streams (synchronously)
667+ # Close all request streams asynchronously
660668 for key in request_stream_keys :
661669 try :
662- # Get the stream
663- stream = self ._request_streams .get (key )
664- if stream :
665- # We must use close() here, not aclose() since this is a sync method
666- stream .close ()
670+ await self ._clean_up_memory_streams (key )
667671 except Exception as e :
668672 logger .debug (f"Error closing stream { key } during termination: { e } " )
669673
670674 # Clear the request streams dictionary immediately
671675 self ._request_streams .clear ()
676+ try :
677+ if self ._read_stream_writer is not None :
678+ await self ._read_stream_writer .aclose ()
679+ if self ._read_stream is not None :
680+ await self ._read_stream .aclose ()
681+ if self ._write_stream_reader is not None :
682+ await self ._write_stream_reader .aclose ()
683+ if self ._write_stream is not None :
684+ await self ._write_stream .aclose ()
685+ except Exception as e :
686+ logger .debug (f"Error closing streams: { e } " )
672687
673688 async def _handle_unsupported_request (self , request : Request , send : Send ) -> None :
674689 """Handle unsupported HTTP methods."""
@@ -756,10 +771,10 @@ async def send_event(event_message: EventMessage) -> None:
756771
757772 # If stream ID not in mapping, create it
758773 if stream_id and stream_id not in self ._request_streams :
759- msg_writer , msg_reader = anyio . create_memory_object_stream [
760- EventMessage
761- ]( 0 )
762- self ._request_streams [stream_id ] = msg_writer
774+ self . _request_streams [ stream_id ] = (
775+ anyio . create_memory_object_stream [ EventMessage ]( 0 )
776+ )
777+ msg_reader = self ._request_streams [stream_id ][ 1 ]
763778
764779 # Forward messages to SSE
765780 async with msg_reader :
@@ -781,6 +796,9 @@ async def send_event(event_message: EventMessage) -> None:
781796 await response (request .scope , request .receive , send )
782797 except Exception as e :
783798 logger .exception (f"Error in replay response: { e } " )
799+ finally :
800+ await sse_stream_writer .aclose ()
801+ await sse_stream_reader .aclose ()
784802
785803 except Exception as e :
786804 logger .exception (f"Error replaying events: { e } " )
@@ -818,7 +836,9 @@ async def connect(
818836
819837 # Store the streams
820838 self ._read_stream_writer = read_stream_writer
839+ self ._read_stream = read_stream
821840 self ._write_stream_reader = write_stream_reader
841+ self ._write_stream = write_stream
822842
823843 # Start a task group for message routing
824844 async with anyio .create_task_group () as tg :
@@ -863,7 +883,7 @@ async def message_router():
863883 if request_stream_id in self ._request_streams :
864884 try :
865885 # Send both the message and the event ID
866- await self ._request_streams [request_stream_id ].send (
886+ await self ._request_streams [request_stream_id ][ 0 ] .send (
867887 EventMessage (message , event_id )
868888 )
869889 except (
@@ -872,6 +892,12 @@ async def message_router():
872892 ):
873893 # Stream might be closed, remove from registry
874894 self ._request_streams .pop (request_stream_id , None )
895+ else :
896+ logging .debug (
897+ f"""Request stream { request_stream_id } not found
898+ for message. Still processing message as the client
899+ might reconnect and replay."""
900+ )
875901 except Exception as e :
876902 logger .exception (f"Error in message router: { e } " )
877903
@@ -882,9 +908,19 @@ async def message_router():
882908 # Yield the streams for the caller to use
883909 yield read_stream , write_stream
884910 finally :
885- for stream in list (self ._request_streams .values ()):
911+ for stream_id in list (self ._request_streams .keys ()):
886912 try :
887- await stream .aclose ()
888- except Exception :
913+ await self ._clean_up_memory_streams (stream_id )
914+ except Exception as e :
915+ logger .debug (f"Error closing request stream: { e } " )
889916 pass
890917 self ._request_streams .clear ()
918+
919+ # Clean up the read and write streams
920+ try :
921+ await read_stream_writer .aclose ()
922+ await read_stream .aclose ()
923+ await write_stream_reader .aclose ()
924+ await write_stream .aclose ()
925+ except Exception as e :
926+ logger .debug (f"Error closing streams: { e } " )
0 commit comments