Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ reqwest = { version = "0.12", default-features = false, features = [
"json",
"stream",
], optional = true }
sse-stream = { version = "0.2.0", optional = true }

sse-stream = { version = "0.2", optional = true }

http = { version = "1", optional = true }
url = { version = "2.4", optional = true }

Expand All @@ -57,7 +59,9 @@ axum = { version = "0.8", features = [], optional = true }
rand = { version = "0.9", optional = true }
tokio-stream = { version = "0.1", optional = true }
uuid = { version = "1", features = ["v4"], optional = true }

http-body = { version = "1", optional = true }
http-body-util = { version = "0.1", optional = true }
bytes = { version = "1", optional = true }
# macro
rmcp-macros = { version = "0.1", workspace = true, optional = true }

Expand All @@ -74,7 +78,17 @@ reqwest = ["__reqwest", "reqwest?/rustls-tls"]

reqwest-tls-no-provider = ["__reqwest", "reqwest?/rustls-tls-no-provider"]

axum = ["dep:axum"]
server-side-http = [
"uuid",
"dep:rand",
"dep:tokio-stream",
"dep:http",
"dep:http-body",
"dep:http-body-util",
"dep:bytes",
"dep:sse-stream",
"tower",
]
# SSE client
client-side-sse = ["dep:sse-stream", "dep:http"]

Expand All @@ -97,15 +111,12 @@ transport-child-process = [
transport-sse-server = [
"transport-async-rw",
"transport-worker",
"axum",
"dep:rand",
"dep:tokio-stream",
"uuid",
"server-side-http",
"dep:axum",
]
transport-streamable-http-server = [
"transport-streamable-http-server-session",
"axum",
"uuid",
"server-side-http",
]
transport-streamable-http-server-session = [
"transport-async-rw",
Expand Down
10 changes: 5 additions & 5 deletions crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ pub struct RequestContext<R: ServiceRole> {
}

/// Use this function to skip initialization process
pub async fn serve_directly<R, S, T, E, A>(
pub fn serve_directly<R, S, T, E, A>(
service: S,
transport: T,
peer_info: Option<R::PeerInfo>,
Expand All @@ -499,11 +499,11 @@ where
T: IntoTransport<R, E, A>,
E: std::error::Error + Send + Sync + 'static,
{
serve_directly_with_ct(service, transport, peer_info, Default::default()).await
serve_directly_with_ct(service, transport, peer_info, Default::default())
}

/// Use this function to skip initialization process
pub async fn serve_directly_with_ct<R, S, T, E, A>(
pub fn serve_directly_with_ct<R, S, T, E, A>(
service: S,
transport: T,
peer_info: Option<R::PeerInfo>,
Expand All @@ -516,11 +516,11 @@ where
E: std::error::Error + Send + Sync + 'static,
{
let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info);
serve_inner(service, transport, peer, peer_rx, ct).await
serve_inner(service, transport, peer, peer_rx, ct)
}

#[instrument(skip_all)]
async fn serve_inner<R, S, T, E, A>(
fn serve_inner<R, S, T, E, A>(
service: S,
transport: T,
peer: Peer<R>,
Expand Down
6 changes: 3 additions & 3 deletions crates/rmcp/src/service/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ pub async fn serve_client<S, T, E, A>(
where
S: Service<RoleClient>,
T: IntoTransport<RoleClient, E, A>,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
{
serve_client_with_ct(service, transport, Default::default()).await
}
Expand All @@ -124,7 +124,7 @@ pub async fn serve_client_with_ct<S, T, E, A>(
where
S: Service<RoleClient>,
T: IntoTransport<RoleClient, E, A>,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
{
let mut transport = transport.into_transport();
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
Expand Down Expand Up @@ -175,7 +175,7 @@ where
context: "send initialized notification".into(),
})?;
let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result));
Ok(serve_inner(service, transport, peer, peer_rx, ct).await)
Ok(serve_inner(service, transport, peer, peer_rx, ct))
}

macro_rules! method {
Expand Down
8 changes: 4 additions & 4 deletions crates/rmcp/src/service/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl<S: Service<RoleServer>> ServiceExt<RoleServer> for S {
) -> impl Future<Output = Result<RunningService<RoleServer, Self>, ServerInitializeError<E>>> + Send
where
T: IntoTransport<RoleServer, E, A>,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
Self: Sized,
{
serve_server_with_ct(self, transport, ct)
Expand All @@ -84,7 +84,7 @@ pub async fn serve_server<S, T, E, A>(
where
S: Service<RoleServer>,
T: IntoTransport<RoleServer, E, A>,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
{
serve_server_with_ct(service, transport, CancellationToken::new()).await
}
Expand Down Expand Up @@ -143,7 +143,7 @@ pub async fn serve_server_with_ct<S, T, E, A>(
where
S: Service<RoleServer>,
T: IntoTransport<RoleServer, E, A>,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
{
let mut transport = transport.into_transport();
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
Expand Down Expand Up @@ -212,7 +212,7 @@ where
};
let _ = service.handle_notification(notification).await;
// Continue processing service
Ok(serve_inner(service, transport, peer, peer_rx, ct).await)
Ok(serve_inner(service, transport, peer, peer_rx, ct))
}

macro_rules! method {
Expand Down
74 changes: 70 additions & 4 deletions crates/rmcp/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! | transport | client | server |
//! |:-: |:-: |:-: |
//! | std IO | [`child_process::TokioChildProcess`] | [`io::stdio`] |
//! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::session::create_session`] |
//! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::StreamableHttpService`] |
//! | sse | [`sse_client::SseClientTransport`] | [`sse_server::SseServer`] |
//!
//!## Helper Transport Types
Expand Down Expand Up @@ -64,6 +64,8 @@
//! }
//! ```

use std::sync::Arc;

use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage};

pub mod sink_stream;
Expand Down Expand Up @@ -122,7 +124,7 @@ pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, Authorized
pub mod streamable_http_server;
#[cfg(feature = "transport-streamable-http-server")]
#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))]
pub use streamable_http_server::axum::StreamableHttpServer;
pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHttpService};

#[cfg(feature = "transport-streamable-http-client")]
#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))]
Expand All @@ -138,7 +140,7 @@ pub trait Transport<R>: Send
where
R: ServiceRole,
{
type Error;
type Error: std::error::Error + Send + Sync + 'static;
/// Send a message to the transport
///
/// Notice that the future returned by this function should be `Send` and `'static`.
Expand Down Expand Up @@ -169,9 +171,73 @@ impl<R, T, E> IntoTransport<R, E, TransportAdapterIdentity> for T
where
T: Transport<R, Error = E> + Send + 'static,
R: ServiceRole,
E: std::error::Error + Send + 'static,
E: std::error::Error + Send + Sync + 'static,
{
fn into_transport(self) -> impl Transport<R, Error = E> + 'static {
self
}
}

/// A transport that can send a single message and then close itself
pub struct OneshotTransport<R>
where
R: ServiceRole,
{
message: Option<RxJsonRpcMessage<R>>,
sender: tokio::sync::mpsc::Sender<TxJsonRpcMessage<R>>,
finished_signal: Arc<tokio::sync::Notify>,
}

impl<R> OneshotTransport<R>
where
R: ServiceRole,
{
pub fn new(
message: RxJsonRpcMessage<R>,
) -> (Self, tokio::sync::mpsc::Receiver<TxJsonRpcMessage<R>>) {
let (sender, receiver) = tokio::sync::mpsc::channel(16);
(
Self {
message: Some(message),
sender,
finished_signal: Arc::new(tokio::sync::Notify::new()),
},
receiver,
)
}
}

impl<R> Transport<R> for OneshotTransport<R>
where
R: ServiceRole,
{
type Error = tokio::sync::mpsc::error::SendError<TxJsonRpcMessage<R>>;

fn send(
&mut self,
item: TxJsonRpcMessage<R>,
) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
let sender = self.sender.clone();
let terminate = matches!(item, TxJsonRpcMessage::<R>::Response(_));
let signal = self.finished_signal.clone();
async move {
sender.send(item).await?;
if terminate {
signal.notify_waiters();
}
Ok(())
}
}

async fn receive(&mut self) -> Option<RxJsonRpcMessage<R>> {
if self.message.is_none() {
self.finished_signal.notified().await;
}
self.message.take()
}

fn close(&mut self) -> impl Future<Output = Result<(), Self::Error>> + Send {
self.message.take();
std::future::ready(Ok(()))
}
}
2 changes: 1 addition & 1 deletion crates/rmcp/src/transport/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
feature = "transport-streamable-http-server",
feature = "transport-sse-server"
))]
pub mod axum;
pub mod server_side_http;

pub mod http_header;

Expand Down
9 changes: 0 additions & 9 deletions crates/rmcp/src/transport/common/axum.rs

This file was deleted.

Loading