From dbc75379c8190db0271fd7af66285a9342c991bc Mon Sep 17 00:00:00 2001 From: 4t145 Date: Wed, 28 May 2025 19:24:16 +0800 Subject: [PATCH] feat: stateless mode of streamable http client --- .../src/transport/common/client_side_sse.rs | 36 +++++ .../src/transport/streamable_http_client.rs | 141 ++++++++++-------- 2 files changed, 117 insertions(+), 60 deletions(-) diff --git a/crates/rmcp/src/transport/common/client_side_sse.rs b/crates/rmcp/src/transport/common/client_side_sse.rs index f3549627..6fbf67ea 100644 --- a/crates/rmcp/src/transport/common/client_side_sse.rs +++ b/crates/rmcp/src/transport/common/client_side_sse.rs @@ -76,6 +76,28 @@ impl SseRetryPolicy for ExponentialBackoff { } } +#[derive(Debug, Clone, Copy, Default)] +pub struct NeverRetry; + +impl SseRetryPolicy for NeverRetry { + fn retry(&self, _current_times: usize) -> Option { + None + } +} + +#[derive(Debug, Default)] +pub struct NeverReconnect { + error: Option, +} + +impl SseStreamReconnect for NeverReconnect { + type Error = E; + type Future = futures::future::Ready>; + fn retry_connection(&mut self, _last_event_id: Option<&str>) -> Self::Future { + futures::future::ready(Err(self.error.take().expect("should not be called again"))) + } +} + pub(crate) trait SseStreamReconnect { type Error: std::error::Error; type Future: Future> + Send; @@ -111,6 +133,20 @@ impl SseAutoReconnectStream { } } +impl SseAutoReconnectStream> { + pub fn never_reconnect(stream: BoxedSseResponse, error_when_reconnect: E) -> Self { + Self { + retry_policy: Arc::new(NeverRetry), + last_event_id: None, + server_retry_interval: None, + connector: NeverReconnect { + error: Some(error_when_reconnect), + }, + state: SseAutoReconnectStreamState::Connected { stream }, + } + } +} + pin_project_lite::pin_project! { #[project = SseAutoReconnectStreamStateProj] pub enum SseAutoReconnectStreamState { diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 4c5c7c99..7ec20fad 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -1,6 +1,6 @@ use std::{borrow::Cow, sync::Arc, time::Duration}; -use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; +use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream}; pub use sse_stream::Error as SseError; use sse_stream::Sse; use thiserror::Error; @@ -193,8 +193,7 @@ impl StreamableHttpClientWorker { client: C::default(), config: StreamableHttpClientTransportConfig { uri: url.into(), - retry_config: Arc::new(ExponentialBackoff::default()), - channel_buffer_capacity: 16, + ..Default::default() }, } } @@ -208,7 +207,9 @@ impl StreamableHttpClientWorker { impl StreamableHttpClientWorker { async fn execute_sse_stream( - sse_stream: SseAutoReconnectStream>, + sse_stream: impl Stream>> + + Send + + 'static, sse_worker_tx: tokio::sync::mpsc::Sender, ct: CancellationToken, ) -> Result<(), StreamableHttpError> { @@ -277,16 +278,19 @@ impl Worker for StreamableHttpClientWorker { .map_err(WorkerQuitReason::fatal_context( "process initialize response", ))?; - let Some(session_id) = session_id else { - return Err(WorkerQuitReason::fatal( - "missing session id in initialize response", - "process initialize response", - )); + let session_id: Option> = if let Some(session_id) = session_id { + Some(session_id.into()) + } else { + if !self.config.allow_stateless { + return Err(WorkerQuitReason::fatal( + "missing session id in initialize response", + "process initialize response", + )); + } + None }; - let session_id: Arc = session_id.into(); - // delete session when drop guard is dropped - { + if let Some(session_id) = &session_id { let ct = transport_task_ct.clone(); let client = self.client.clone(); let session_id = session_id.clone(); @@ -322,7 +326,7 @@ impl Worker for StreamableHttpClientWorker { .post_message( config.uri.clone(), initialized_notification.message, - Some(session_id.clone()), + session_id.clone(), None, ) .await @@ -340,38 +344,40 @@ impl Worker for StreamableHttpClientWorker { StreamResult(Result<(), StreamableHttpError>), } let mut streams = tokio::task::JoinSet::new(); - match self - .client - .get_stream(config.uri.clone(), session_id.clone(), None, None) - .await - { - Ok(stream) => { - let sse_stream = SseAutoReconnectStream::new( - stream, - StreamableHttpClientReconnect { - client: self.client.clone(), - session_id: session_id.clone(), - uri: config.uri.clone(), - }, - self.config.retry_config.clone(), - ); - streams.spawn(Self::execute_sse_stream( - sse_stream, - sse_worker_tx.clone(), - transport_task_ct.child_token(), - )); - tracing::debug!("got common stream"); - } - Err(StreamableHttpError::SeverDoesNotSupportSse) => { - tracing::debug!("server doesn't support sse, skip common stream"); - } - Err(e) => { - // fail to get common stream - tracing::error!("fail to get common stream: {e}"); - return Err(WorkerQuitReason::fatal( - "fail to get general purpose event stream", - "get general purpose event stream", - )); + if let Some(session_id) = &session_id { + match self + .client + .get_stream(config.uri.clone(), session_id.clone(), None, None) + .await + { + Ok(stream) => { + let sse_stream = SseAutoReconnectStream::new( + stream, + StreamableHttpClientReconnect { + client: self.client.clone(), + session_id: session_id.clone(), + uri: config.uri.clone(), + }, + self.config.retry_config.clone(), + ); + streams.spawn(Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + transport_task_ct.child_token(), + )); + tracing::debug!("got common stream"); + } + Err(StreamableHttpError::SeverDoesNotSupportSse) => { + tracing::debug!("server doesn't support sse, skip common stream"); + } + Err(e) => { + // fail to get common stream + tracing::error!("fail to get common stream: {e}"); + return Err(WorkerQuitReason::fatal( + "fail to get general purpose event stream", + "get general purpose event stream", + )); + } } } loop { @@ -407,7 +413,7 @@ impl Worker for StreamableHttpClientWorker { let WorkerSendRequest { message, responder } = send_request; let response = self .client - .post_message(config.uri.clone(), message, Some(session_id.clone()), None) + .post_message(config.uri.clone(), message, session_id.clone(), None) .await; let send_result = match response { Err(e) => Err(e), @@ -420,20 +426,32 @@ impl Worker for StreamableHttpClientWorker { Ok(()) } Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { - let sse_stream = SseAutoReconnectStream::new( - stream, - StreamableHttpClientReconnect { - client: self.client.clone(), - session_id: session_id.clone(), - uri: config.uri.clone(), - }, - self.config.retry_config.clone(), - ); - streams.spawn(Self::execute_sse_stream( - sse_stream, - sse_worker_tx.clone(), - transport_task_ct.child_token(), - )); + if let Some(session_id) = &session_id { + let sse_stream = SseAutoReconnectStream::new( + stream, + StreamableHttpClientReconnect { + client: self.client.clone(), + session_id: session_id.clone(), + uri: config.uri.clone(), + }, + self.config.retry_config.clone(), + ); + streams.spawn(Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + transport_task_ct.child_token(), + )); + } else { + let sse_stream = SseAutoReconnectStream::never_reconnect( + stream, + StreamableHttpError::::UnexpectedEndOfStream, + ); + streams.spawn(Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + transport_task_ct.child_token(), + )); + } tracing::trace!("got new sse stream"); Ok(()) } @@ -470,6 +488,8 @@ pub struct StreamableHttpClientTransportConfig { pub uri: Arc, pub retry_config: Arc, pub channel_buffer_capacity: usize, + /// if true, the transport will not require a session to be established + pub allow_stateless: bool, } impl StreamableHttpClientTransportConfig { @@ -487,6 +507,7 @@ impl Default for StreamableHttpClientTransportConfig { uri: "localhost".into(), retry_config: Arc::new(ExponentialBackoff::default()), channel_buffer_capacity: 16, + allow_stateless: true, } } }