diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index e9aa78b5..486719d4 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -40,13 +40,14 @@ reqwest = { version = "0.12", default-features = false, features = [ "stream", ], optional = true } sse-stream = { version = "0.1.4", optional = true } +http = { version = "1", optional = true } url = { version = "2.4", optional = true } # For tower compatibility tower-service = { version = "0.3", optional = true } # for child process transport -process-wrap = { version = "8.2", features = ["tokio1"], optional = true} +process-wrap = { version = "8.2", features = ["tokio1"], optional = true } # for ws transport # tokio-tungstenite ={ version = "0.26", optional = true } @@ -75,7 +76,7 @@ reqwest-tls-no-provider = ["__reqwest", "reqwest?/rustls-tls-no-provider"] axum = ["dep:axum"] # SSE client -client-side-sse = ["dep:sse-stream"] +client-side-sse = ["dep:sse-stream", "dep:http"] transport-sse-client = ["client-side-sse", "transport-worker"] @@ -83,10 +84,7 @@ transport-worker = ["dep:tokio-stream"] # Streamable HTTP client -transport-streamable-http-client = [ - "client-side-sse", - "transport-worker", -] +transport-streamable-http-client = ["client-side-sse", "transport-worker"] transport-async-rw = ["tokio/io-util", "tokio-util/codec"] @@ -98,6 +96,7 @@ transport-child-process = [ ] transport-sse-server = [ "transport-async-rw", + "transport-worker", "axum", "dep:rand", "dep:tokio-stream", diff --git a/crates/rmcp/src/transport/common/auth/sse_client.rs b/crates/rmcp/src/transport/common/auth/sse_client.rs index e6603dd7..009593e1 100644 --- a/crates/rmcp/src/transport/common/auth/sse_client.rs +++ b/crates/rmcp/src/transport/common/auth/sse_client.rs @@ -1,3 +1,5 @@ +use http::Uri; + use crate::transport::{ auth::AuthClient, sse_client::{SseClient, SseTransportError}, @@ -10,7 +12,7 @@ where async fn post_message( &self, - uri: std::sync::Arc, + uri: Uri, message: crate::model::ClientJsonRpcMessage, mut auth_token: Option, ) -> Result<(), SseTransportError> { @@ -25,7 +27,7 @@ where async fn get_stream( &self, - uri: std::sync::Arc, + uri: Uri, last_event_id: Option, mut auth_token: Option, ) -> Result< diff --git a/crates/rmcp/src/transport/common/reqwest/sse_client.rs b/crates/rmcp/src/transport/common/reqwest/sse_client.rs index 99262c22..37fe7841 100644 --- a/crates/rmcp/src/transport/common/reqwest/sse_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/sse_client.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use futures::StreamExt; +use http::Uri; use reqwest::header::ACCEPT; use sse_stream::SseStream; @@ -15,11 +16,11 @@ impl SseClient for reqwest::Client { async fn post_message( &self, - uri: std::sync::Arc, + uri: Uri, message: crate::model::ClientJsonRpcMessage, auth_token: Option, ) -> Result<(), SseTransportError> { - let mut request_builder = self.post(uri.as_ref()).json(&message); + let mut request_builder = self.post(uri.to_string()).json(&message); if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } @@ -33,7 +34,7 @@ impl SseClient for reqwest::Client { async fn get_stream( &self, - uri: std::sync::Arc, + uri: Uri, last_event_id: Option, auth_token: Option, ) -> Result< @@ -41,7 +42,7 @@ impl SseClient for reqwest::Client { SseTransportError, > { let mut request_builder = self - .get(uri.as_ref()) + .get(uri.to_string()) .header(ACCEPT, EVENT_STREAM_MIME_TYPE); if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); @@ -73,7 +74,7 @@ impl SseClientTransport { SseClientTransport::start_with_client( reqwest::Client::default(), SseClientConfig { - uri: uri.into(), + sse_endpoint: uri.into(), ..Default::default() }, ) diff --git a/crates/rmcp/src/transport/sse_client.rs b/crates/rmcp/src/transport/sse_client.rs index c8e00bea..f9d0e434 100644 --- a/crates/rmcp/src/transport/sse_client.rs +++ b/crates/rmcp/src/transport/sse_client.rs @@ -2,6 +2,7 @@ use std::{pin::Pin, sync::Arc}; use futures::{StreamExt, future::BoxFuture}; +use http::Uri; use reqwest::header::HeaderValue; use sse_stream::Error as SseError; use thiserror::Error; @@ -32,6 +33,10 @@ pub enum SseTransportError { #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] #[error("Auth error: {0}")] Auth(#[from] crate::transport::auth::AuthError), + #[error("Invalid uri: {0}")] + InvalidUri(#[from] http::uri::InvalidUri), + #[error("Invalid uri parts: {0}")] + InvalidUriParts(#[from] http::uri::InvalidUriParts), } impl From for SseTransportError { @@ -44,13 +49,13 @@ pub trait SseClient: Clone + Send + Sync + 'static { type Error: std::error::Error + Send + Sync + 'static; fn post_message( &self, - uri: Arc, + uri: Uri, message: ClientJsonRpcMessage, auth_token: Option, ) -> impl Future>> + Send + '_; fn get_stream( &self, - uri: Arc, + uri: Uri, last_event_id: Option, auth_token: Option, ) -> impl Future>> + Send + '_; @@ -58,7 +63,7 @@ pub trait SseClient: Clone + Send + Sync + 'static { struct SseClientReconnect { pub client: C, - pub uri: Arc, + pub uri: Uri, } impl SseStreamReconnect for SseClientReconnect { @@ -75,7 +80,7 @@ type ServerMessageStream = Pin { client: C, config: SseClientConfig, - post_uri: Arc, + message_endpoint: Uri, stream: Option>, } @@ -89,7 +94,7 @@ impl Transport for SseClientTransport { item: crate::service::TxJsonRpcMessage, ) -> impl Future> + Send + 'static { let client = self.client.clone(); - let uri = self.post_uri.clone(); + let uri = self.message_endpoint.clone(); async move { client.post_message(uri, item, None).await } } async fn close(&mut self) -> Result<(), Self::Error> { @@ -112,9 +117,11 @@ impl SseClientTransport { client: C, config: SseClientConfig, ) -> Result> { - let mut sse_stream = client.get_stream(config.uri.clone(), None, None).await?; - let endpoint = if let Some(endpoint) = config.use_endpoint.clone() { - endpoint + let sse_endpoint = config.sse_endpoint.as_ref().parse::()?; + + let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?; + let message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() { + endpoint.parse::()? } else { // wait the endpoint event loop { @@ -125,27 +132,29 @@ impl SseClientTransport { let Some("endpoint") = sse.event.as_deref() else { continue; }; - break sse.data.unwrap_or_default(); + let sse_endpoint = sse.data.unwrap_or_default(); + break sse_endpoint.parse::()?; } }; - let post_uri: Arc = format!( - "{}/{}", - config.uri.trim_end_matches("/"), - endpoint.trim_start_matches("/") - ) - .into(); + + // sse: -> + let message_endpoint = { + let mut sse_endpoint_parts = sse_endpoint.clone().into_parts(); + sse_endpoint_parts.path_and_query = message_endpoint.into_parts().path_and_query; + Uri::from_parts(sse_endpoint_parts)? + }; let stream = Box::pin(SseAutoReconnectStream::new( sse_stream, SseClientReconnect { client: client.clone(), - uri: config.uri.clone(), + uri: sse_endpoint.clone(), }, config.retry_policy.clone(), )); Ok(Self { client, config, - post_uri, + message_endpoint, stream: Some(stream), }) } @@ -153,18 +162,29 @@ impl SseClientTransport { #[derive(Debug, Clone)] pub struct SseClientConfig { - pub uri: Arc, + /// client sse endpoint + /// + /// # How this client resolve the message endpoint + /// if sse_endpoint has this format: ``, + /// then the message endpoint will be ``. + /// + /// For example, if you config the sse_endpoint as `http://example.com/some_path/sse`, + /// and the server send the message endpoint event as `message?session_id=123`, + /// then the message endpoint will be `http://example.com/message`. + /// + /// This follow the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/zh-CN/docs/Web/API/URL/URL) + pub sse_endpoint: Arc, pub retry_policy: Arc, /// if this is settled, the client will use this endpoint to send message and skip get the endpoint event - pub use_endpoint: Option, + pub use_message_endpoint: Option, } impl Default for SseClientConfig { fn default() -> Self { Self { - uri: "".into(), + sse_endpoint: "".into(), retry_policy: Arc::new(super::common::client_side_sse::FixedInterval::default()), - use_endpoint: None, + use_message_endpoint: None, } } } diff --git a/examples/clients/src/oauth_client.rs b/examples/clients/src/oauth_client.rs index 8002e2a0..47818352 100644 --- a/examples/clients/src/oauth_client.rs +++ b/examples/clients/src/oauth_client.rs @@ -147,7 +147,7 @@ async fn main() -> Result<()> { let transport = SseClientTransport::start_with_client( client, SseClientConfig { - uri: MCP_SSE_URL.into(), + sse_endpoint: MCP_SSE_URL.into(), ..Default::default() }, )