11use std:: { borrow:: Cow , sync:: Arc , time:: Duration } ;
22
3- use futures:: { StreamExt , future:: BoxFuture , stream:: BoxStream } ;
3+ use futures:: { Stream , StreamExt , future:: BoxFuture , stream:: BoxStream } ;
44pub use sse_stream:: Error as SseError ;
55use sse_stream:: Sse ;
66use thiserror:: Error ;
@@ -193,8 +193,7 @@ impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
193193 client : C :: default ( ) ,
194194 config : StreamableHttpClientTransportConfig {
195195 uri : url. into ( ) ,
196- retry_config : Arc :: new ( ExponentialBackoff :: default ( ) ) ,
197- channel_buffer_capacity : 16 ,
196+ ..Default :: default ( )
198197 } ,
199198 }
200199 }
@@ -208,7 +207,9 @@ impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
208207
209208impl < C : StreamableHttpClient > StreamableHttpClientWorker < C > {
210209 async fn execute_sse_stream (
211- sse_stream : SseAutoReconnectStream < StreamableHttpClientReconnect < C > > ,
210+ sse_stream : impl Stream < Item = Result < ServerJsonRpcMessage , StreamableHttpError < C :: Error > > >
211+ + Send
212+ + ' static ,
212213 sse_worker_tx : tokio:: sync:: mpsc:: Sender < ServerJsonRpcMessage > ,
213214 ct : CancellationToken ,
214215 ) -> Result < ( ) , StreamableHttpError < C :: Error > > {
@@ -277,16 +278,19 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
277278 . map_err ( WorkerQuitReason :: fatal_context (
278279 "process initialize response" ,
279280 ) ) ?;
280- let Some ( session_id) = session_id else {
281- return Err ( WorkerQuitReason :: fatal (
282- "missing session id in initialize response" ,
283- "process initialize response" ,
284- ) ) ;
281+ let session_id: Option < Arc < str > > = if let Some ( session_id) = session_id {
282+ Some ( session_id. into ( ) )
283+ } else {
284+ if !self . config . allow_stateless {
285+ return Err ( WorkerQuitReason :: fatal (
286+ "missing session id in initialize response" ,
287+ "process initialize response" ,
288+ ) ) ;
289+ }
290+ None
285291 } ;
286- let session_id: Arc < str > = session_id. into ( ) ;
287-
288292 // delete session when drop guard is dropped
289- {
293+ if let Some ( session_id ) = & session_id {
290294 let ct = transport_task_ct. clone ( ) ;
291295 let client = self . client . clone ( ) ;
292296 let session_id = session_id. clone ( ) ;
@@ -322,7 +326,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
322326 . post_message (
323327 config. uri . clone ( ) ,
324328 initialized_notification. message ,
325- Some ( session_id. clone ( ) ) ,
329+ session_id. clone ( ) ,
326330 None ,
327331 )
328332 . await
@@ -340,38 +344,40 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
340344 StreamResult ( Result < ( ) , StreamableHttpError < E > > ) ,
341345 }
342346 let mut streams = tokio:: task:: JoinSet :: new ( ) ;
343- match self
344- . client
345- . get_stream ( config. uri . clone ( ) , session_id. clone ( ) , None , None )
346- . await
347- {
348- Ok ( stream) => {
349- let sse_stream = SseAutoReconnectStream :: new (
350- stream,
351- StreamableHttpClientReconnect {
352- client : self . client . clone ( ) ,
353- session_id : session_id. clone ( ) ,
354- uri : config. uri . clone ( ) ,
355- } ,
356- self . config . retry_config . clone ( ) ,
357- ) ;
358- streams. spawn ( Self :: execute_sse_stream (
359- sse_stream,
360- sse_worker_tx. clone ( ) ,
361- transport_task_ct. child_token ( ) ,
362- ) ) ;
363- tracing:: debug!( "got common stream" ) ;
364- }
365- Err ( StreamableHttpError :: SeverDoesNotSupportSse ) => {
366- tracing:: debug!( "server doesn't support sse, skip common stream" ) ;
367- }
368- Err ( e) => {
369- // fail to get common stream
370- tracing:: error!( "fail to get common stream: {e}" ) ;
371- return Err ( WorkerQuitReason :: fatal (
372- "fail to get general purpose event stream" ,
373- "get general purpose event stream" ,
374- ) ) ;
347+ if let Some ( session_id) = & session_id {
348+ match self
349+ . client
350+ . get_stream ( config. uri . clone ( ) , session_id. clone ( ) , None , None )
351+ . await
352+ {
353+ Ok ( stream) => {
354+ let sse_stream = SseAutoReconnectStream :: new (
355+ stream,
356+ StreamableHttpClientReconnect {
357+ client : self . client . clone ( ) ,
358+ session_id : session_id. clone ( ) ,
359+ uri : config. uri . clone ( ) ,
360+ } ,
361+ self . config . retry_config . clone ( ) ,
362+ ) ;
363+ streams. spawn ( Self :: execute_sse_stream (
364+ sse_stream,
365+ sse_worker_tx. clone ( ) ,
366+ transport_task_ct. child_token ( ) ,
367+ ) ) ;
368+ tracing:: debug!( "got common stream" ) ;
369+ }
370+ Err ( StreamableHttpError :: SeverDoesNotSupportSse ) => {
371+ tracing:: debug!( "server doesn't support sse, skip common stream" ) ;
372+ }
373+ Err ( e) => {
374+ // fail to get common stream
375+ tracing:: error!( "fail to get common stream: {e}" ) ;
376+ return Err ( WorkerQuitReason :: fatal (
377+ "fail to get general purpose event stream" ,
378+ "get general purpose event stream" ,
379+ ) ) ;
380+ }
375381 }
376382 }
377383 loop {
@@ -407,7 +413,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
407413 let WorkerSendRequest { message, responder } = send_request;
408414 let response = self
409415 . client
410- . post_message ( config. uri . clone ( ) , message, Some ( session_id. clone ( ) ) , None )
416+ . post_message ( config. uri . clone ( ) , message, session_id. clone ( ) , None )
411417 . await ;
412418 let send_result = match response {
413419 Err ( e) => Err ( e) ,
@@ -420,20 +426,32 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
420426 Ok ( ( ) )
421427 }
422428 Ok ( StreamableHttpPostResponse :: Sse ( stream, ..) ) => {
423- let sse_stream = SseAutoReconnectStream :: new (
424- stream,
425- StreamableHttpClientReconnect {
426- client : self . client . clone ( ) ,
427- session_id : session_id. clone ( ) ,
428- uri : config. uri . clone ( ) ,
429- } ,
430- self . config . retry_config . clone ( ) ,
431- ) ;
432- streams. spawn ( Self :: execute_sse_stream (
433- sse_stream,
434- sse_worker_tx. clone ( ) ,
435- transport_task_ct. child_token ( ) ,
436- ) ) ;
429+ if let Some ( session_id) = & session_id {
430+ let sse_stream = SseAutoReconnectStream :: new (
431+ stream,
432+ StreamableHttpClientReconnect {
433+ client : self . client . clone ( ) ,
434+ session_id : session_id. clone ( ) ,
435+ uri : config. uri . clone ( ) ,
436+ } ,
437+ self . config . retry_config . clone ( ) ,
438+ ) ;
439+ streams. spawn ( Self :: execute_sse_stream (
440+ sse_stream,
441+ sse_worker_tx. clone ( ) ,
442+ transport_task_ct. child_token ( ) ,
443+ ) ) ;
444+ } else {
445+ let sse_stream = SseAutoReconnectStream :: never_reconnect (
446+ stream,
447+ StreamableHttpError :: < C :: Error > :: UnexpectedEndOfStream ,
448+ ) ;
449+ streams. spawn ( Self :: execute_sse_stream (
450+ sse_stream,
451+ sse_worker_tx. clone ( ) ,
452+ transport_task_ct. child_token ( ) ,
453+ ) ) ;
454+ }
437455 tracing:: trace!( "got new sse stream" ) ;
438456 Ok ( ( ) )
439457 }
@@ -470,6 +488,8 @@ pub struct StreamableHttpClientTransportConfig {
470488 pub uri : Arc < str > ,
471489 pub retry_config : Arc < dyn SseRetryPolicy > ,
472490 pub channel_buffer_capacity : usize ,
491+ /// if true, the transport will not require a session to be established
492+ pub allow_stateless : bool ,
473493}
474494
475495impl StreamableHttpClientTransportConfig {
@@ -487,6 +507,7 @@ impl Default for StreamableHttpClientTransportConfig {
487507 uri : "localhost" . into ( ) ,
488508 retry_config : Arc :: new ( ExponentialBackoff :: default ( ) ) ,
489509 channel_buffer_capacity : 16 ,
510+ allow_stateless : true ,
490511 }
491512 }
492513}
0 commit comments