Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@ reqwest = { version = "0.12", default-features = false, features = [
"stream",
], optional = true }
sse-stream = { version = "0.1.4", optional = true }
http = { version = "1", optional = true }
url = { version = "2.4", optional = true }

# For tower compatibility
tower-service = { version = "0.3", optional = true }

# for child process transport
process-wrap = { version = "8.2", features = ["tokio1"], optional = true}
process-wrap = { version = "8.2", features = ["tokio1"], optional = true }

# for ws transport
# tokio-tungstenite ={ version = "0.26", optional = true }
Expand Down Expand Up @@ -75,18 +76,15 @@ reqwest-tls-no-provider = ["__reqwest", "reqwest?/rustls-tls-no-provider"]

axum = ["dep:axum"]
# SSE client
client-side-sse = ["dep:sse-stream"]
client-side-sse = ["dep:sse-stream", "dep:http"]

transport-sse-client = ["client-side-sse", "transport-worker"]

transport-worker = ["dep:tokio-stream"]


# Streamable HTTP client
transport-streamable-http-client = [
"client-side-sse",
"transport-worker",
]
transport-streamable-http-client = ["client-side-sse", "transport-worker"]


transport-async-rw = ["tokio/io-util", "tokio-util/codec"]
Expand All @@ -98,6 +96,7 @@ transport-child-process = [
]
transport-sse-server = [
"transport-async-rw",
"transport-worker",
"axum",
"dep:rand",
"dep:tokio-stream",
Expand Down
6 changes: 4 additions & 2 deletions crates/rmcp/src/transport/common/auth/sse_client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use http::Uri;

use crate::transport::{
auth::AuthClient,
sse_client::{SseClient, SseTransportError},
Expand All @@ -10,7 +12,7 @@ where

async fn post_message(
&self,
uri: std::sync::Arc<str>,
uri: Uri,
message: crate::model::ClientJsonRpcMessage,
mut auth_token: Option<String>,
) -> Result<(), SseTransportError<Self::Error>> {
Expand All @@ -25,7 +27,7 @@ where

async fn get_stream(
&self,
uri: std::sync::Arc<str>,
uri: Uri,
last_event_id: Option<String>,
mut auth_token: Option<String>,
) -> Result<
Expand Down
11 changes: 6 additions & 5 deletions crates/rmcp/src/transport/common/reqwest/sse_client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use futures::StreamExt;
use http::Uri;
use reqwest::header::ACCEPT;
use sse_stream::SseStream;

Expand All @@ -15,11 +16,11 @@ impl SseClient for reqwest::Client {

async fn post_message(
&self,
uri: std::sync::Arc<str>,
uri: Uri,
message: crate::model::ClientJsonRpcMessage,
auth_token: Option<String>,
) -> Result<(), SseTransportError<Self::Error>> {
let mut request_builder = self.post(uri.as_ref()).json(&message);
let mut request_builder = self.post(uri.to_string()).json(&message);
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
Expand All @@ -33,15 +34,15 @@ impl SseClient for reqwest::Client {

async fn get_stream(
&self,
uri: std::sync::Arc<str>,
uri: Uri,
last_event_id: Option<String>,
auth_token: Option<String>,
) -> Result<
crate::transport::common::client_side_sse::BoxedSseResponse,
SseTransportError<Self::Error>,
> {
let mut request_builder = self
.get(uri.as_ref())
.get(uri.to_string())
.header(ACCEPT, EVENT_STREAM_MIME_TYPE);
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
Expand Down Expand Up @@ -73,7 +74,7 @@ impl SseClientTransport<reqwest::Client> {
SseClientTransport::start_with_client(
reqwest::Client::default(),
SseClientConfig {
uri: uri.into(),
sse_endpoint: uri.into(),
..Default::default()
},
)
Expand Down
62 changes: 41 additions & 21 deletions crates/rmcp/src/transport/sse_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use std::{pin::Pin, sync::Arc};

use futures::{StreamExt, future::BoxFuture};
use http::Uri;
use reqwest::header::HeaderValue;
use sse_stream::Error as SseError;
use thiserror::Error;
Expand Down Expand Up @@ -32,6 +33,10 @@ pub enum SseTransportError<E: std::error::Error + Send + Sync + 'static> {
#[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
#[error("Auth error: {0}")]
Auth(#[from] crate::transport::auth::AuthError),
#[error("Invalid uri: {0}")]
InvalidUri(#[from] http::uri::InvalidUri),
#[error("Invalid uri parts: {0}")]
InvalidUriParts(#[from] http::uri::InvalidUriParts),
}

impl From<reqwest::Error> for SseTransportError<reqwest::Error> {
Expand All @@ -44,21 +49,21 @@ pub trait SseClient: Clone + Send + Sync + 'static {
type Error: std::error::Error + Send + Sync + 'static;
fn post_message(
&self,
uri: Arc<str>,
uri: Uri,
message: ClientJsonRpcMessage,
auth_token: Option<String>,
) -> impl Future<Output = Result<(), SseTransportError<Self::Error>>> + Send + '_;
fn get_stream(
&self,
uri: Arc<str>,
uri: Uri,
last_event_id: Option<String>,
auth_token: Option<String>,
) -> impl Future<Output = Result<BoxedSseResponse, SseTransportError<Self::Error>>> + Send + '_;
}

struct SseClientReconnect<C> {
pub client: C,
pub uri: Arc<str>,
pub uri: Uri,
}

impl<C: SseClient> SseStreamReconnect for SseClientReconnect<C> {
Expand All @@ -75,7 +80,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
pub struct SseClientTransport<C: SseClient> {
client: C,
config: SseClientConfig,
post_uri: Arc<str>,
message_endpoint: Uri,
stream: Option<ServerMessageStream<C>>,
}

Expand All @@ -89,7 +94,7 @@ impl<C: SseClient> Transport<RoleClient> for SseClientTransport<C> {
item: crate::service::TxJsonRpcMessage<RoleClient>,
) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
let client = self.client.clone();
let uri = self.post_uri.clone();
let uri = self.message_endpoint.clone();
async move { client.post_message(uri, item, None).await }
}
async fn close(&mut self) -> Result<(), Self::Error> {
Expand All @@ -112,9 +117,11 @@ impl<C: SseClient> SseClientTransport<C> {
client: C,
config: SseClientConfig,
) -> Result<Self, SseTransportError<C::Error>> {
let mut sse_stream = client.get_stream(config.uri.clone(), None, None).await?;
let endpoint = if let Some(endpoint) = config.use_endpoint.clone() {
endpoint
let sse_endpoint = config.sse_endpoint.as_ref().parse::<http::Uri>()?;

let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?;
let message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() {
endpoint.parse::<http::Uri>()?
} else {
// wait the endpoint event
loop {
Expand All @@ -125,46 +132,59 @@ impl<C: SseClient> SseClientTransport<C> {
let Some("endpoint") = sse.event.as_deref() else {
continue;
};
break sse.data.unwrap_or_default();
let sse_endpoint = sse.data.unwrap_or_default();
break sse_endpoint.parse::<http::Uri>()?;
}
};
let post_uri: Arc<str> = format!(
"{}/{}",
config.uri.trim_end_matches("/"),
endpoint.trim_start_matches("/")
)
.into();

// sse: <authority><sse_pq> -> <authority><message_pq>
let message_endpoint = {
let mut sse_endpoint_parts = sse_endpoint.clone().into_parts();
sse_endpoint_parts.path_and_query = message_endpoint.into_parts().path_and_query;
Uri::from_parts(sse_endpoint_parts)?
};
let stream = Box::pin(SseAutoReconnectStream::new(
sse_stream,
SseClientReconnect {
client: client.clone(),
uri: config.uri.clone(),
uri: sse_endpoint.clone(),
},
config.retry_policy.clone(),
));
Ok(Self {
client,
config,
post_uri,
message_endpoint,
stream: Some(stream),
})
}
}

#[derive(Debug, Clone)]
pub struct SseClientConfig {
pub uri: Arc<str>,
/// client sse endpoint
///
/// # How this client resolve the message endpoint
/// if sse_endpoint has this format: `<schema><authority?><sse_pq>`,
/// then the message endpoint will be `<schema><authority?><message_pq>`.
///
/// For example, if you config the sse_endpoint as `http://example.com/some_path/sse`,
/// and the server send the message endpoint event as `message?session_id=123`,
/// then the message endpoint will be `http://example.com/message`.
///
/// This follow the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/zh-CN/docs/Web/API/URL/URL)
pub sse_endpoint: Arc<str>,
pub retry_policy: Arc<dyn SseRetryPolicy>,
/// if this is settled, the client will use this endpoint to send message and skip get the endpoint event
pub use_endpoint: Option<String>,
pub use_message_endpoint: Option<String>,
}

impl Default for SseClientConfig {
fn default() -> Self {
Self {
uri: "".into(),
sse_endpoint: "".into(),
retry_policy: Arc::new(super::common::client_side_sse::FixedInterval::default()),
use_endpoint: None,
use_message_endpoint: None,
}
}
}
2 changes: 1 addition & 1 deletion examples/clients/src/oauth_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async fn main() -> Result<()> {
let transport = SseClientTransport::start_with_client(
client,
SseClientConfig {
uri: MCP_SSE_URL.into(),
sse_endpoint: MCP_SSE_URL.into(),
..Default::default()
},
)
Expand Down
Loading