Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions crates/rmcp/src/transport/common/client_side_sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ impl SseRetryPolicy for ExponentialBackoff {
}
}

#[derive(Debug, Clone, Copy, Default)]
pub struct NeverRetry;

impl SseRetryPolicy for NeverRetry {
fn retry(&self, _current_times: usize) -> Option<Duration> {
None
}
}

#[derive(Debug, Default)]
pub struct NeverReconnect<E> {
error: Option<E>,
}

impl<E: std::error::Error + Send> SseStreamReconnect for NeverReconnect<E> {
type Error = E;
type Future = futures::future::Ready<Result<BoxedSseResponse, Self::Error>>;
fn retry_connection(&mut self, _last_event_id: Option<&str>) -> Self::Future {
futures::future::ready(Err(self.error.take().expect("should not be called again")))
}
}

pub(crate) trait SseStreamReconnect {
type Error: std::error::Error;
type Future: Future<Output = Result<BoxedSseResponse, Self::Error>> + Send;
Expand Down Expand Up @@ -111,6 +133,20 @@ impl<R: SseStreamReconnect> SseAutoReconnectStream<R> {
}
}

impl<E: std::error::Error + Send> SseAutoReconnectStream<NeverReconnect<E>> {
pub fn never_reconnect(stream: BoxedSseResponse, error_when_reconnect: E) -> Self {
Self {
retry_policy: Arc::new(NeverRetry),
last_event_id: None,
server_retry_interval: None,
connector: NeverReconnect {
error: Some(error_when_reconnect),
},
state: SseAutoReconnectStreamState::Connected { stream },
}
}
}

pin_project_lite::pin_project! {
#[project = SseAutoReconnectStreamStateProj]
pub enum SseAutoReconnectStreamState<F> {
Expand Down
141 changes: 81 additions & 60 deletions crates/rmcp/src/transport/streamable_http_client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{borrow::Cow, sync::Arc, time::Duration};

use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
pub use sse_stream::Error as SseError;
use sse_stream::Sse;
use thiserror::Error;
Expand Down Expand Up @@ -193,8 +193,7 @@ impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
client: C::default(),
config: StreamableHttpClientTransportConfig {
uri: url.into(),
retry_config: Arc::new(ExponentialBackoff::default()),
channel_buffer_capacity: 16,
..Default::default()
},
}
}
Expand All @@ -208,7 +207,9 @@ impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {

impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
async fn execute_sse_stream(
sse_stream: SseAutoReconnectStream<StreamableHttpClientReconnect<C>>,
sse_stream: impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>>
+ Send
+ 'static,
sse_worker_tx: tokio::sync::mpsc::Sender<ServerJsonRpcMessage>,
ct: CancellationToken,
) -> Result<(), StreamableHttpError<C::Error>> {
Expand Down Expand Up @@ -277,16 +278,19 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
.map_err(WorkerQuitReason::fatal_context(
"process initialize response",
))?;
let Some(session_id) = session_id else {
return Err(WorkerQuitReason::fatal(
"missing session id in initialize response",
"process initialize response",
));
let session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
Some(session_id.into())
} else {
if !self.config.allow_stateless {
return Err(WorkerQuitReason::fatal(
"missing session id in initialize response",
"process initialize response",
));
}
None
};
let session_id: Arc<str> = session_id.into();

// delete session when drop guard is dropped
{
if let Some(session_id) = &session_id {
let ct = transport_task_ct.clone();
let client = self.client.clone();
let session_id = session_id.clone();
Expand Down Expand Up @@ -322,7 +326,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
.post_message(
config.uri.clone(),
initialized_notification.message,
Some(session_id.clone()),
session_id.clone(),
None,
)
.await
Expand All @@ -340,38 +344,40 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
StreamResult(Result<(), StreamableHttpError<E>>),
}
let mut streams = tokio::task::JoinSet::new();
match self
.client
.get_stream(config.uri.clone(), session_id.clone(), None, None)
.await
{
Ok(stream) => {
let sse_stream = SseAutoReconnectStream::new(
stream,
StreamableHttpClientReconnect {
client: self.client.clone(),
session_id: session_id.clone(),
uri: config.uri.clone(),
},
self.config.retry_config.clone(),
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
transport_task_ct.child_token(),
));
tracing::debug!("got common stream");
}
Err(StreamableHttpError::SeverDoesNotSupportSse) => {
tracing::debug!("server doesn't support sse, skip common stream");
}
Err(e) => {
// fail to get common stream
tracing::error!("fail to get common stream: {e}");
return Err(WorkerQuitReason::fatal(
"fail to get general purpose event stream",
"get general purpose event stream",
));
if let Some(session_id) = &session_id {
match self
.client
.get_stream(config.uri.clone(), session_id.clone(), None, None)
.await
{
Ok(stream) => {
let sse_stream = SseAutoReconnectStream::new(
stream,
StreamableHttpClientReconnect {
client: self.client.clone(),
session_id: session_id.clone(),
uri: config.uri.clone(),
},
self.config.retry_config.clone(),
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
transport_task_ct.child_token(),
));
tracing::debug!("got common stream");
}
Err(StreamableHttpError::SeverDoesNotSupportSse) => {
tracing::debug!("server doesn't support sse, skip common stream");
}
Err(e) => {
// fail to get common stream
tracing::error!("fail to get common stream: {e}");
return Err(WorkerQuitReason::fatal(
"fail to get general purpose event stream",
"get general purpose event stream",
));
}
}
}
loop {
Expand Down Expand Up @@ -407,7 +413,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
let WorkerSendRequest { message, responder } = send_request;
let response = self
.client
.post_message(config.uri.clone(), message, Some(session_id.clone()), None)
.post_message(config.uri.clone(), message, session_id.clone(), None)
.await;
let send_result = match response {
Err(e) => Err(e),
Expand All @@ -420,20 +426,32 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
Ok(())
}
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
let sse_stream = SseAutoReconnectStream::new(
stream,
StreamableHttpClientReconnect {
client: self.client.clone(),
session_id: session_id.clone(),
uri: config.uri.clone(),
},
self.config.retry_config.clone(),
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
transport_task_ct.child_token(),
));
if let Some(session_id) = &session_id {
let sse_stream = SseAutoReconnectStream::new(
stream,
StreamableHttpClientReconnect {
client: self.client.clone(),
session_id: session_id.clone(),
uri: config.uri.clone(),
},
self.config.retry_config.clone(),
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
transport_task_ct.child_token(),
));
} else {
let sse_stream = SseAutoReconnectStream::never_reconnect(
stream,
StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
transport_task_ct.child_token(),
));
}
tracing::trace!("got new sse stream");
Ok(())
}
Expand Down Expand Up @@ -470,6 +488,8 @@ pub struct StreamableHttpClientTransportConfig {
pub uri: Arc<str>,
pub retry_config: Arc<dyn SseRetryPolicy>,
pub channel_buffer_capacity: usize,
/// if true, the transport will not require a session to be established
pub allow_stateless: bool,
}

impl StreamableHttpClientTransportConfig {
Expand All @@ -487,6 +507,7 @@ impl Default for StreamableHttpClientTransportConfig {
uri: "localhost".into(),
retry_config: Arc::new(ExponentialBackoff::default()),
channel_buffer_capacity: 16,
allow_stateless: true,
}
}
}