From 59ba72ab338928872e02d0e93a58a40143d88d7c Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 27 May 2025 14:06:23 +0800 Subject: [PATCH 1/4] refactor: streamable http server as tower service --- crates/rmcp/Cargo.toml | 29 +- crates/rmcp/src/service.rs | 10 +- crates/rmcp/src/service/client.rs | 6 +- crates/rmcp/src/service/server.rs | 8 +- crates/rmcp/src/transport.rs | 74 +- crates/rmcp/src/transport/common.rs | 2 +- crates/rmcp/src/transport/common/axum.rs | 9 - .../src/transport/common/sever_side_http.rs | 140 +++ crates/rmcp/src/transport/sink_stream.rs | 5 +- crates/rmcp/src/transport/sse_server.rs | 4 +- .../src/transport/streamable_http_server.rs | 9 +- .../transport/streamable_http_server/axum.rs | 349 ------- .../streamable_http_server/session.rs | 808 +--------------- .../streamable_http_server/session/local.rs | 880 ++++++++++++++++++ .../streamable_http_server/session/never.rs | 100 ++ .../transport/streamable_http_server/tower.rs | 406 ++++++++ crates/rmcp/tests/test_with_js.rs | 33 +- .../tests/test_with_js/streamable_client.js | 2 +- examples/servers/Cargo.toml | 27 +- examples/servers/README.md | 9 +- .../src/counter_hyper_streamable_http.rs | 36 + examples/servers/src/counter_streamhttp.rs | 19 +- 22 files changed, 1784 insertions(+), 1181 deletions(-) delete mode 100644 crates/rmcp/src/transport/common/axum.rs create mode 100644 crates/rmcp/src/transport/common/sever_side_http.rs delete mode 100644 crates/rmcp/src/transport/streamable_http_server/axum.rs create mode 100644 crates/rmcp/src/transport/streamable_http_server/session/local.rs create mode 100644 crates/rmcp/src/transport/streamable_http_server/session/never.rs create mode 100644 crates/rmcp/src/transport/streamable_http_server/tower.rs create mode 100644 examples/servers/src/counter_hyper_streamable_http.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 0381901c..1cd3db41 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -39,7 +39,9 @@ reqwest = { version = "0.12", default-features = false, features = [ "json", "stream", ], optional = true } -sse-stream = { version = "0.2.0", optional = true } + +sse-stream = { version = "0.2", optional = true } + http = { version = "1", optional = true } url = { version = "2.4", optional = true } @@ -57,7 +59,9 @@ axum = { version = "0.8", features = [], optional = true } rand = { version = "0.9", optional = true } tokio-stream = { version = "0.1", optional = true } uuid = { version = "1", features = ["v4"], optional = true } - +http-body = { version = "1", optional = true } +http-body-util = { version = "0.1", optional = true } +bytes = { version = "1", optional = true } # macro rmcp-macros = { version = "0.1", workspace = true, optional = true } @@ -74,7 +78,17 @@ reqwest = ["__reqwest", "reqwest?/rustls-tls"] reqwest-tls-no-provider = ["__reqwest", "reqwest?/rustls-tls-no-provider"] -axum = ["dep:axum"] +server-side-http = [ + "uuid", + "dep:rand", + "dep:tokio-stream", + "dep:http", + "dep:http-body", + "dep:http-body-util", + "dep:bytes", + "dep:sse-stream", + "tower", +] # SSE client client-side-sse = ["dep:sse-stream", "dep:http"] @@ -97,15 +111,12 @@ transport-child-process = [ transport-sse-server = [ "transport-async-rw", "transport-worker", - "axum", - "dep:rand", - "dep:tokio-stream", - "uuid", + "server-side-http", + "dep:axum", ] transport-streamable-http-server = [ "transport-streamable-http-server-session", - "axum", - "uuid", + "server-side-http", ] transport-streamable-http-server-session = [ "transport-async-rw", diff 
--git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 11b7e67d..8fc05511 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -488,7 +488,7 @@ pub struct RequestContext { } /// Use this function to skip initialization process -pub async fn serve_directly( +pub fn serve_directly( service: S, transport: T, peer_info: Option, @@ -499,11 +499,11 @@ where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, { - serve_directly_with_ct(service, transport, peer_info, Default::default()).await + serve_directly_with_ct(service, transport, peer_info, Default::default()) } /// Use this function to skip initialization process -pub async fn serve_directly_with_ct( +pub fn serve_directly_with_ct( service: S, transport: T, peer_info: Option, @@ -516,11 +516,11 @@ where E: std::error::Error + Send + Sync + 'static, { let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info); - serve_inner(service, transport, peer, peer_rx, ct).await + serve_inner(service, transport, peer, peer_rx, ct) } #[instrument(skip_all)] -async fn serve_inner( +fn serve_inner( service: S, transport: T, peer: Peer, diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index f44d176c..53a3e203 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -111,7 +111,7 @@ pub async fn serve_client( where S: Service, T: IntoTransport, - E: std::error::Error + From + Send + Sync + 'static, + E: std::error::Error + Send + Sync + 'static, { serve_client_with_ct(service, transport, Default::default()).await } @@ -124,7 +124,7 @@ pub async fn serve_client_with_ct( where S: Service, T: IntoTransport, - E: std::error::Error + From + Send + Sync + 'static, + E: std::error::Error + Send + Sync + 'static, { let mut transport = transport.into_transport(); let id_provider = >::default(); @@ -175,7 +175,7 @@ where context: "send initialized notification".into(), })?; let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result)); - Ok(serve_inner(service, transport, peer, peer_rx, ct).await) + Ok(serve_inner(service, transport, peer, peer_rx, ct)) } macro_rules! method { diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index 6825282b..f2c5ca67 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -70,7 +70,7 @@ impl> ServiceExt for S { ) -> impl Future, ServerInitializeError>> + Send where T: IntoTransport, - E: std::error::Error + From + Send + Sync + 'static, + E: std::error::Error + Send + Sync + 'static, Self: Sized, { serve_server_with_ct(self, transport, ct) @@ -84,7 +84,7 @@ pub async fn serve_server( where S: Service, T: IntoTransport, - E: std::error::Error + From + Send + Sync + 'static, + E: std::error::Error + Send + Sync + 'static, { serve_server_with_ct(service, transport, CancellationToken::new()).await } @@ -143,7 +143,7 @@ pub async fn serve_server_with_ct( where S: Service, T: IntoTransport, - E: std::error::Error + From + Send + Sync + 'static, + E: std::error::Error + Send + Sync + 'static, { let mut transport = transport.into_transport(); let id_provider = >::default(); @@ -212,7 +212,7 @@ where }; let _ = service.handle_notification(notification).await; // Continue processing service - Ok(serve_inner(service, transport, peer, peer_rx, ct).await) + Ok(serve_inner(service, transport, peer, peer_rx, ct)) } macro_rules! 
method { diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index ec1b03ce..20e6ce75 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -7,7 +7,7 @@ //! | transport | client | server | //! |:-: |:-: |:-: | //! | std IO | [`child_process::TokioChildProcess`] | [`io::stdio`] | -//! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::session::create_session`] | +//! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::StreamableHttpService`] | //! | sse | [`sse_client::SseClientTransport`] | [`sse_server::SseServer`] | //! //!## Helper Transport Types @@ -64,6 +64,8 @@ //! } //! ``` +use std::sync::Arc; + use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; pub mod sink_stream; @@ -122,7 +124,7 @@ pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, Authorized pub mod streamable_http_server; #[cfg(feature = "transport-streamable-http-server")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] -pub use streamable_http_server::axum::StreamableHttpServer; +pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHttpService}; #[cfg(feature = "transport-streamable-http-client")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))] @@ -138,7 +140,7 @@ pub trait Transport: Send where R: ServiceRole, { - type Error; + type Error: std::error::Error + Send + Sync + 'static; /// Send a message to the transport /// /// Notice that the future returned by this function should be `Send` and `'static`. @@ -169,9 +171,73 @@ impl IntoTransport for T where T: Transport + Send + 'static, R: ServiceRole, - E: std::error::Error + Send + 'static, + E: std::error::Error + Send + Sync + 'static, { fn into_transport(self) -> impl Transport + 'static { self } } + +/// A transport that can send a single message and then close itself +pub struct OneshotTransport +where + R: ServiceRole, +{ + message: Option>, + sender: tokio::sync::mpsc::Sender>, + finished_signal: Arc, +} + +impl OneshotTransport +where + R: ServiceRole, +{ + pub fn new( + message: RxJsonRpcMessage, + ) -> (Self, tokio::sync::mpsc::Receiver>) { + let (sender, receiver) = tokio::sync::mpsc::channel(16); + ( + Self { + message: Some(message), + sender, + finished_signal: Arc::new(tokio::sync::Notify::new()), + }, + receiver, + ) + } +} + +impl Transport for OneshotTransport +where + R: ServiceRole, +{ + type Error = tokio::sync::mpsc::error::SendError>; + + fn send( + &mut self, + item: TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + let sender = self.sender.clone(); + let terminate = matches!(item, TxJsonRpcMessage::::Response(_)); + let signal = self.finished_signal.clone(); + async move { + sender.send(item).await?; + if terminate { + signal.notify_waiters(); + } + Ok(()) + } + } + + async fn receive(&mut self) -> Option> { + if self.message.is_none() { + self.finished_signal.notified().await; + } + self.message.take() + } + + fn close(&mut self) -> impl Future> + Send { + self.message.take(); + std::future::ready(Ok(())) + } +} diff --git a/crates/rmcp/src/transport/common.rs b/crates/rmcp/src/transport/common.rs index aa15f900..03c78266 100644 --- a/crates/rmcp/src/transport/common.rs +++ b/crates/rmcp/src/transport/common.rs @@ -2,7 +2,7 @@ feature = "transport-streamable-http-server", feature = "transport-sse-server" ))] -pub mod axum; +pub mod sever_side_http; pub mod 
http_header; diff --git a/crates/rmcp/src/transport/common/axum.rs b/crates/rmcp/src/transport/common/axum.rs deleted file mode 100644 index a2611575..00000000 --- a/crates/rmcp/src/transport/common/axum.rs +++ /dev/null @@ -1,9 +0,0 @@ -use std::{sync::Arc, time::Duration}; - -pub type SessionId = Arc; - -pub fn session_id() -> SessionId { - uuid::Uuid::new_v4().to_string().into() -} - -pub const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15); diff --git a/crates/rmcp/src/transport/common/sever_side_http.rs b/crates/rmcp/src/transport/common/sever_side_http.rs new file mode 100644 index 00000000..7ff5f5c1 --- /dev/null +++ b/crates/rmcp/src/transport/common/sever_side_http.rs @@ -0,0 +1,140 @@ +use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; + +use bytes::{Buf, Bytes}; +use http::Response; +use http_body::Body; +use http_body_util::{BodyExt, Empty, Full, combinators::UnsyncBoxBody}; +use sse_stream::{KeepAlive, Sse, SseBody}; + +use super::http_header::EVENT_STREAM_MIME_TYPE; +use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; + +pub type SessionId = Arc; + +pub fn session_id() -> SessionId { + uuid::Uuid::new_v4().to_string().into() +} + +pub const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15); + +pub(crate) type BoxResponse = Response>; + +pub(crate) fn accepted_response() -> Response> { + Response::builder() + .status(http::StatusCode::ACCEPTED) + .body(Empty::new().boxed_unsync()) + .expect("valid response") +} +pin_project_lite::pin_project! { + struct TokioTimer { + #[pin] + sleep: tokio::time::Sleep, + } +} +impl Future for TokioTimer { + type Output = (); + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + this.sleep.poll(cx) + } +} +impl sse_stream::Timer for TokioTimer { + fn from_duration(duration: Duration) -> Self { + Self { + sleep: tokio::time::sleep(duration), + } + } + + fn reset(self: std::pin::Pin<&mut Self>, when: std::time::Instant) { + let this = self.project(); + this.sleep.reset(tokio::time::Instant::from_std(when)); + } +} + +#[derive(Debug, Clone)] +pub struct ServerSseMessage { + pub event_id: Option, + pub message: Arc, +} + +pub(crate) fn sse_stream_response( + stream: impl futures::Stream + Send + 'static, + keep_alive: Option, +) -> Response> { + use futures::StreamExt; + let stream = SseBody::new(stream.map(|message| { + let data = serde_json::to_string(&message.message).expect("valid message"); + let mut sse = Sse::default().data(data); + sse.id = message.event_id; + Result::::Ok(sse) + })); + let stream = match keep_alive { + Some(duration) => stream + .with_keep_alive::(KeepAlive::new().interval(duration)) + .boxed_unsync(), + None => stream.boxed_unsync(), + }; + Response::builder() + .status(http::StatusCode::OK) + .header(http::header::CONTENT_TYPE, EVENT_STREAM_MIME_TYPE) + .header(http::header::CACHE_CONTROL, "no-cache") + .body(stream) + .expect("valid response") +} + +pub(crate) const fn internal_error_response( + context: &str, +) -> impl FnOnce(E) -> Response> { + move |error| { + tracing::error!("Internal server error when {context}: {error}"); + Response::builder() + .status(http::StatusCode::INTERNAL_SERVER_ERROR) + .body( + Full::new(Bytes::from(format!( + "Encounter an error when {context}: {error}" + ))) + .boxed_unsync(), + ) + .expect("valid response") + } +} + +pub(crate) async fn expect_json( + body: B, +) -> Result>> +where + B: Body + Send + 'static, + B::Error: Display, +{ + 
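// Read the whole request body, then decode it as a single `ClientJsonRpcMessage`;
+    // a decode failure is returned to the client as 415 Unsupported Media Type,
+    // while a failure to read the body is returned as 500 Internal Server Error.
+   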
match body.collect().await { + Ok(bytes) => { + match serde_json::from_reader::<_, ClientJsonRpcMessage>(bytes.aggregate().reader()) { + Ok(message) => Ok(message), + Err(e) => { + let response = Response::builder() + .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE) + .body( + Full::new(Bytes::from(format!("fail to deserialize request body {e}"))) + .boxed_unsync(), + ) + .expect("valid response"); + Err(response) + } + } + } + Err(e) => { + let response = Response::builder() + .status(http::StatusCode::INTERNAL_SERVER_ERROR) + .body( + Full::new(Bytes::from(format!("Failed to read request body: {e}"))) + .boxed_unsync(), + ) + .expect("valid response"); + Err(response) + } + } +} diff --git a/crates/rmcp/src/transport/sink_stream.rs b/crates/rmcp/src/transport/sink_stream.rs index 7c0cdaa5..f3174392 100644 --- a/crates/rmcp/src/transport/sink_stream.rs +++ b/crates/rmcp/src/transport/sink_stream.rs @@ -24,6 +24,7 @@ impl Transport for SinkStreamTransport where St: Send + Stream> + Unpin, Si: Send + Sink> + Unpin + 'static, + Si::Error: std::error::Error + Send + Sync + 'static, { type Error = Si::Error; @@ -56,7 +57,7 @@ where Role: ServiceRole, Si: Send + Sink> + Unpin + 'static, St: Send + Stream> + Unpin + 'static, - Si::Error: std::error::Error + Send + 'static, + Si::Error: std::error::Error + Send + Sync + 'static, { fn into_transport(self) -> impl Transport + 'static { SinkStreamTransport::new(self.0, self.1) @@ -68,7 +69,7 @@ impl IntoTransport for where Role: ServiceRole, S: Sink> + Stream> + Send + 'static, - S::Error: std::error::Error + Send + 'static, + S::Error: std::error::Error + Send + Sync + 'static, { fn into_transport(self) -> impl Transport + 'static { use futures::StreamExt; diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index 4e3e5d43..0fc12b2e 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -19,7 +19,7 @@ use crate::{ RoleServer, Service, model::ClientJsonRpcMessage, service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct}, - transport::common::axum::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, + transport::common::sever_side_http::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, }; type TxStore = @@ -313,7 +313,7 @@ impl SseServer { let service = service_provider(); let ct = self.config.ct.child_token(); tokio::spawn(async move { - let server = serve_directly_with_ct(service, transport, None, ct).await; + let server = serve_directly_with_ct(service, transport, None, ct); server.waiting().await?; tokio::io::Result::Ok(()) }); diff --git a/crates/rmcp/src/transport/streamable_http_server.rs b/crates/rmcp/src/transport/streamable_http_server.rs index 930de8ab..733fc5e5 100644 --- a/crates/rmcp/src/transport/streamable_http_server.rs +++ b/crates/rmcp/src/transport/streamable_http_server.rs @@ -1,5 +1,8 @@ +pub mod session; #[cfg(feature = "transport-streamable-http-server")] #[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] -pub mod axum; -pub mod session; -pub use session::{SessionConfig, create_session}; +pub mod tower; +pub use session::{SessionId, SessionManager}; +#[cfg(feature = "transport-streamable-http-server")] +#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))] +pub use tower::{StreamableHttpServerConfig, StreamableHttpService}; diff --git a/crates/rmcp/src/transport/streamable_http_server/axum.rs b/crates/rmcp/src/transport/streamable_http_server/axum.rs deleted file mode 100644 
index 016dee6b..00000000 --- a/crates/rmcp/src/transport/streamable_http_server/axum.rs +++ /dev/null @@ -1,349 +0,0 @@ -use std::{ - collections::HashMap, - io, - net::{Ipv6Addr, SocketAddr, SocketAddrV6}, - sync::Arc, - time::Duration, -}; - -use axum::{ - Json, Router, - extract::State, - http::{HeaderMap, HeaderValue, StatusCode, request::Parts}, - response::{ - IntoResponse, Response, - sse::{Event, KeepAlive, Sse}, - }, - routing::get, -}; -use futures::Stream; -use tokio_stream::wrappers::ReceiverStream; -use tokio_util::sync::CancellationToken; -use tracing::Instrument; - -use super::session::{EventId, SessionHandle, SessionWorker, StreamableHttpMessageReceiver}; -use crate::{ - RoleServer, Service, - model::ClientJsonRpcMessage, - transport::common::{ - axum::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, - http_header::{HEADER_LAST_EVENT_ID, HEADER_SESSION_ID}, - }, -}; -type SessionManager = Arc>>; - -#[derive(Clone)] -struct App { - session_manager: SessionManager, - transport_tx: tokio::sync::mpsc::UnboundedSender, - sse_ping_interval: Duration, -} - -impl App { - pub fn new( - sse_ping_interval: Duration, - ) -> (Self, tokio::sync::mpsc::UnboundedReceiver) { - let (transport_tx, transport_rx) = tokio::sync::mpsc::unbounded_channel(); - ( - Self { - session_manager: Default::default(), - transport_tx, - sse_ping_interval, - }, - transport_rx, - ) - } -} - -fn receiver_as_stream( - receiver: StreamableHttpMessageReceiver, -) -> impl Stream> { - use futures::StreamExt; - ReceiverStream::new(receiver.inner).map(|message| { - match serde_json::to_string(&message.message) { - Ok(bytes) => Ok(Event::default() - .event("message") - .data(&bytes) - .id(message.event_id.to_string())), - Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), - } - }) -} - -async fn post_handler( - State(app): State, - parts: Parts, - Json(mut message): Json, -) -> Result { - use futures::StreamExt; - if let Some(session_id) = parts.headers.get(HEADER_SESSION_ID).cloned() { - let session_id = session_id - .to_str() - .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()).into_response())?; - tracing::debug!(session_id, ?message, "new client message"); - let handle = { - let sm = app.session_manager.read().await; - let session = sm - .get(session_id) - .ok_or((StatusCode::NOT_FOUND, "session not found").into_response())?; - session.clone() - }; - // inject request part - message.insert_extension(parts); - match &message { - ClientJsonRpcMessage::Request(_) | ClientJsonRpcMessage::BatchRequest(_) => { - let receiver = handle.establish_request_wise_channel().await.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("fail to to establish request channel: {e}"), - ) - .into_response() - })?; - let http_request_id = receiver.http_request_id; - if let Err(push_err) = handle.push_message(message, http_request_id).await { - tracing::error!(session_id, ?push_err, "push message error"); - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - format!("fail to push message: {push_err}"), - ) - .into_response()); - } - let stream = - ReceiverStream::new(receiver.inner).map(|message| match serde_json::to_string( - &message.message, - ) { - Ok(bytes) => Ok(Event::default() - .event("message") - .data(&bytes) - .id(message.event_id.to_string())), - Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), - }); - Ok(Sse::new(stream) - .keep_alive(KeepAlive::new().interval(app.sse_ping_interval)) - .into_response()) - } - _ => { - let result = handle.push_message(message, None).await; - if 
result.is_err() { - Err((StatusCode::GONE, "session terminated").into_response()) - } else { - Ok(StatusCode::ACCEPTED.into_response()) - } - } - } - } else { - // expect initialize message - let session_id = session_id(); - // inject request part - message.insert_extension(parts); - let (session, transport) = - super::session::create_session(session_id.clone(), Default::default()); - let Ok(_) = app.transport_tx.send(transport) else { - return Err((StatusCode::GONE, "session terminated").into_response()); - }; - - let response = session.initialize(message).await.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("fail to initialize: {e}"), - ) - .into_response() - })?; - let mut response = Json(response).into_response(); - response.headers_mut().insert( - HEADER_SESSION_ID, - HeaderValue::from_bytes(session_id.as_bytes()).expect("should be valid header value"), - ); - app.session_manager - .write() - .await - .insert(session_id, session); - Ok(response) - } -} - -async fn get_handler( - State(app): State, - header_map: HeaderMap, -) -> Result>>, Response> { - let session_id = header_map - .get(HEADER_SESSION_ID) - .and_then(|v| v.to_str().ok()); - if let Some(session_id) = session_id { - let last_event_id = header_map - .get(HEADER_LAST_EVENT_ID) - .and_then(|v| v.to_str().ok()); - let session = { - let sm = app.session_manager.read().await; - sm.get(session_id) - .ok_or_else(|| { - ( - StatusCode::NOT_FOUND, - format!("session {session_id} not found"), - ) - .into_response() - })? - .clone() - }; - match last_event_id { - Some(last_event_id) => { - let last_event_id = last_event_id.parse::().map_err(|e| { - (StatusCode::BAD_REQUEST, format!("invalid event_id {e}")).into_response() - })?; - let receiver = session.resume(last_event_id).await.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("resume error {e}"), - ) - .into_response() - })?; - let stream = receiver_as_stream(receiver); - Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(app.sse_ping_interval))) - } - None => { - let receiver = session.establish_common_channel().await.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("establish common channel error {e}"), - ) - .into_response() - })?; - let stream = receiver_as_stream(receiver); - Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(app.sse_ping_interval))) - } - } - } else { - Err((StatusCode::BAD_REQUEST, "missing session id").into_response()) - } -} - -async fn delete_handler( - State(app): State, - header_map: HeaderMap, -) -> Result { - if let Some(session_id) = header_map.get(HEADER_SESSION_ID) { - let session_id = session_id - .to_str() - .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()).into_response())?; - let session = { - let mut sm = app.session_manager.write().await; - sm.remove(session_id) - .ok_or((StatusCode::NOT_FOUND, "session not found").into_response())? 
- }; - session.close().await.map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("fail to cancel session {session_id}: tokio join error: {e}"), - ) - .into_response() - })?; - tracing::debug!(session_id, "session deleted"); - Ok(StatusCode::ACCEPTED) - } else { - Err((StatusCode::BAD_REQUEST, "missing session id").into_response()) - } -} - -#[derive(Debug, Clone)] -pub struct StreamableHttpServerConfig { - pub bind: SocketAddr, - pub path: String, - pub ct: CancellationToken, - pub sse_keep_alive: Option, -} -impl Default for StreamableHttpServerConfig { - fn default() -> Self { - Self { - bind: SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 80, 0, 0)), - path: "/".to_string(), - ct: CancellationToken::new(), - sse_keep_alive: None, - } - } -} - -#[derive(Debug)] -pub struct StreamableHttpServer { - transport_rx: tokio::sync::mpsc::UnboundedReceiver, - pub config: StreamableHttpServerConfig, -} - -impl StreamableHttpServer { - pub async fn serve(bind: SocketAddr) -> io::Result { - Self::serve_with_config(StreamableHttpServerConfig { - bind, - ..Default::default() - }) - .await - } - pub async fn serve_with_config(config: StreamableHttpServerConfig) -> io::Result { - let (streamable_http_server, service) = Self::new(config); - let listener = tokio::net::TcpListener::bind(streamable_http_server.config.bind).await?; - let ct = streamable_http_server.config.ct.child_token(); - let server = axum::serve(listener, service).with_graceful_shutdown(async move { - ct.cancelled().await; - tracing::info!("streamable http server cancelled"); - }); - tokio::spawn( - async move { - if let Err(e) = server.await { - tracing::error!(error = %e, "streamable http server shutdown with error"); - } - } - .instrument(tracing::info_span!("streamable-http-server", bind_address = %streamable_http_server.config.bind)), - ); - Ok(streamable_http_server) - } - - /// Warning: This function creates a new StreamableHttpServer instance with the provided configuration. - /// `App.post_path` may be incorrect if using `Router` as an embedded router. 
- pub fn new(config: StreamableHttpServerConfig) -> (StreamableHttpServer, Router) { - let (app, transport_rx) = - App::new(config.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL)); - let router = Router::new() - .route( - &config.path, - get(get_handler).post(post_handler).delete(delete_handler), - ) - .with_state(app); - - let server = StreamableHttpServer { - transport_rx, - config, - }; - - (server, router) - } - - pub fn with_service(mut self, service_provider: F) -> CancellationToken - where - S: Service, - F: Fn() -> S + Send + 'static, - { - use crate::service::ServiceExt; - let ct = self.config.ct.clone(); - tokio::spawn(async move { - while let Some(transport) = self.next_transport().await { - let service = service_provider(); - let ct = self.config.ct.child_token(); - tokio::spawn(async move { - let server = service - .serve_with_ct(transport, ct) - .await - .map_err(tokio::io::Error::other)?; - server.waiting().await?; - tokio::io::Result::Ok(()) - }); - } - }); - ct - } - - pub fn cancel(&self) { - self.config.ct.cancel(); - } - - pub async fn next_transport(&mut self) -> Option { - self.transport_rx.recv().await - } -} diff --git a/crates/rmcp/src/transport/streamable_http_server/session.rs b/crates/rmcp/src/transport/streamable_http_server/session.rs index d1a608f4..3b775fed 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session.rs @@ -1,783 +1,49 @@ -use std::{ - borrow::Cow, - collections::{HashMap, HashSet, VecDeque}, - num::ParseIntError, - sync::Arc, -}; - -use thiserror::Error; -use tokio::sync::{ - mpsc::{Receiver, Sender}, - oneshot, -}; -use tracing::instrument; +use futures::Stream; +pub use crate::transport::common::sever_side_http::SessionId; use crate::{ RoleServer, - model::{ - CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest, - JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam, - ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification, - }, - transport::{ - WorkerTransport, - worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest}, - }, + model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, + transport::common::sever_side_http::ServerSseMessage, }; -#[derive(Debug, Clone)] -pub struct ServerSessionMessage { - pub event_id: EventId, - pub message: Arc, -} - -/// `/request_id>` -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct EventId { - http_request_id: Option, - index: usize, -} - -impl std::fmt::Display for EventId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.index)?; - match &self.http_request_id { - Some(http_request_id) => write!(f, "/{http_request_id}"), - None => write!(f, ""), - } - } -} - -#[derive(Debug, Clone, Error)] -pub enum EventIdParseError { - #[error("Invalid index: {0}")] - InvalidIndex(ParseIntError), - #[error("Invalid numeric request id: {0}")] - InvalidNumericRequestId(ParseIntError), - #[error("Missing request id type")] - InvalidRequestIdType, - #[error("Missing request id")] - MissingRequestId, -} - -impl std::str::FromStr for EventId { - type Err = EventIdParseError; - fn from_str(s: &str) -> Result { - if let Some((index, request_id)) = s.split_once("/") { - let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?; - let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?; - Ok(EventId { - http_request_id: Some(request_id), - index, - }) - } else { - let 
index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?; - Ok(EventId { - http_request_id: None, - index, - }) - } - } -} - -pub use crate::transport::common::axum::SessionId; - -struct CachedTx { - tx: Sender, - cache: VecDeque, - http_request_id: Option, - capacity: usize, -} - -impl CachedTx { - fn new(tx: Sender, http_request_id: Option) -> Self { - Self { - cache: VecDeque::with_capacity(tx.capacity()), - capacity: tx.capacity(), - tx, - http_request_id, - } - } - fn new_common(tx: Sender) -> Self { - Self::new(tx, None) - } - - async fn send(&mut self, message: ServerJsonRpcMessage) { - let index = self.cache.back().map_or(0, |m| m.event_id.index + 1); - let event_id = EventId { - http_request_id: self.http_request_id, - index, - }; - let message = ServerSessionMessage { - event_id: event_id.clone(), - message: Arc::new(message), - }; - if self.cache.len() >= self.capacity { - self.cache.pop_front(); - self.cache.push_back(message.clone()); - } else { - self.cache.push_back(message.clone()); - } - let _ = self.tx.send(message).await.inspect_err(|e| { - let event_id = &e.0.event_id; - tracing::trace!(%event_id, "trying to send message in a closed session") - }); - } +pub mod local; +pub mod never; - async fn sync(&mut self, index: usize) -> Result<(), SessionError> { - let Some(front) = self.cache.front() else { - return Ok(()); - }; - let sync_index = index.saturating_sub(front.event_id.index); - if sync_index > self.cache.len() { - // invalid index - return Err(SessionError::InvalidEventId); - } - for message in self.cache.iter().skip(sync_index) { - let send_result = self.tx.send(message.clone()).await; - if send_result.is_err() { - return Err(SessionError::ChannelClosed( - message.event_id.http_request_id, - )); - } - } - Ok(()) - } -} - -struct HttpRequestWise { - resources: HashSet, - tx: CachedTx, -} - -type HttpRequestId = u64; -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -enum ResourceKey { - McpRequestId(RequestId), - ProgressToken(ProgressToken), -} - -pub struct SessionWorker { - id: SessionId, - next_http_request_id: HttpRequestId, - tx_router: HashMap, - resource_router: HashMap, - common: CachedTx, - event_rx: Receiver, - session_config: SessionConfig, -} - -impl SessionWorker { - pub fn id(&self) -> &SessionId { - &self.id - } -} - -#[derive(Debug, Error)] -pub enum SessionError { - #[error("Invalid request id: {0}")] - DuplicatedRequestId(HttpRequestId), - #[error("Channel closed: {0:?}")] - ChannelClosed(Option), - #[error("Cannot parse event id: {0}")] - EventIdParseError(Cow<'static, str>), - #[error("Session service terminated")] - SessionServiceTerminated, - #[error("Invalid event id")] - InvalidEventId, - #[error("Transport closed")] - TransportClosed, - #[error("IO error: {0}")] - Io(#[from] std::io::Error), - #[error("Tokio join error {0}")] - TokioJoinError(#[from] tokio::task::JoinError), -} - -impl From for std::io::Error { - fn from(value: SessionError) -> Self { - match value { - SessionError::Io(io) => io, - _ => std::io::Error::new(std::io::ErrorKind::Other, format!("Session error: {value}")), - } - } -} - -enum OutboundChannel { - RequestWise { id: HttpRequestId, close: bool }, - Common, -} - -pub struct StreamableHttpMessageReceiver { - pub http_request_id: Option, - pub inner: Receiver, -} - -impl SessionWorker { - fn unregister_resource(&mut self, resource: &ResourceKey) { - if let Some(http_request_id) = self.resource_router.remove(resource) { - tracing::trace!(?resource, http_request_id, "unregister resource"); - if let Some(channel) 
= self.tx_router.get_mut(&http_request_id) { - channel.resources.remove(resource); - if channel.resources.is_empty() { - tracing::debug!(http_request_id, "close http request wise channel"); - self.tx_router.remove(&http_request_id); - } - } - } - } - fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) { - tracing::trace!(?resource, http_request_id, "register resource"); - if let Some(channel) = self.tx_router.get_mut(&http_request_id) { - channel.resources.insert(resource.clone()); - self.resource_router.insert(resource, http_request_id); - } - } - fn register_request( - &mut self, - request: &JsonRpcRequest, - http_request_id: HttpRequestId, - ) { - use crate::model::GetMeta; - self.register_resource( - ResourceKey::McpRequestId(request.id.clone()), - http_request_id, - ); - if let Some(progress_token) = request.request.get_meta().get_progress_token() { - self.register_resource( - ResourceKey::ProgressToken(progress_token.clone()), - http_request_id, - ); - } - } - fn catch_cancellation_notification( - &mut self, - notification: &JsonRpcNotification, - ) { - if let ClientNotification::CancelledNotification(n) = ¬ification.notification { - let request_id = n.params.request_id.clone(); - let resource = ResourceKey::McpRequestId(request_id); - self.unregister_resource(&resource); - } - } - fn next_http_request_id(&mut self) -> HttpRequestId { - let id = self.next_http_request_id; - self.next_http_request_id = self.next_http_request_id.wrapping_add(1); - id - } - async fn establish_request_wise_channel( - &mut self, - ) -> Result { - let http_request_id = self.next_http_request_id(); - let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity); - self.tx_router.insert( - http_request_id, - HttpRequestWise { - resources: Default::default(), - tx: CachedTx::new(tx, Some(http_request_id)), - }, - ); - tracing::debug!(http_request_id, "establish new request wise channel"); - Ok(StreamableHttpMessageReceiver { - http_request_id: Some(http_request_id), - inner: rx, - }) - } - fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel { - match &message { - ServerJsonRpcMessage::Request(_) => OutboundChannel::Common, - ServerJsonRpcMessage::Notification(JsonRpcNotification { - notification: - ServerNotification::ProgressNotification(Notification { - params: ProgressNotificationParam { progress_token, .. }, - .. - }), - .. - }) => { - let id = self - .resource_router - .get(&ResourceKey::ProgressToken(progress_token.clone())); - - if let Some(id) = id { - OutboundChannel::RequestWise { - id: *id, - close: false, - } - } else { - OutboundChannel::Common - } - } - ServerJsonRpcMessage::Notification(JsonRpcNotification { - notification: - ServerNotification::CancelledNotification(Notification { - params: CancelledNotificationParam { request_id, .. }, - .. - }), - .. 
- }) => { - if let Some(id) = self - .resource_router - .get(&ResourceKey::McpRequestId(request_id.clone())) - { - OutboundChannel::RequestWise { - id: *id, - close: false, - } - } else { - OutboundChannel::Common - } - } - ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common, - ServerJsonRpcMessage::Response(json_rpc_response) => { - if let Some(id) = self - .resource_router - .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone())) - { - OutboundChannel::RequestWise { - id: *id, - close: false, - } - } else { - OutboundChannel::Common - } - } - ServerJsonRpcMessage::Error(json_rpc_error) => { - if let Some(id) = self - .resource_router - .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone())) - { - OutboundChannel::RequestWise { - id: *id, - close: false, - } - } else { - OutboundChannel::Common - } - } - ServerJsonRpcMessage::BatchRequest(_) | ServerJsonRpcMessage::BatchResponse(_) => { - // the server side should never yield a batch request or response now - unreachable!("server side won't yield batch request or response") - } - } - } - async fn handle_server_message( - &mut self, - message: ServerJsonRpcMessage, - ) -> Result<(), SessionError> { - let outbound_channel = self.resolve_outbound_channel(&message); - match outbound_channel { - OutboundChannel::RequestWise { id, close } => { - if let Some(request_wise) = self.tx_router.get_mut(&id) { - request_wise.tx.send(message).await; - if close { - self.tx_router.remove(&id); - } - } else { - return Err(SessionError::ChannelClosed(Some(id))); - } - } - OutboundChannel::Common => self.common.send(message).await, - } - Ok(()) - } - async fn resume( - &mut self, - last_event_id: EventId, - ) -> Result { - match last_event_id.http_request_id { - Some(http_request_id) => { - let request_wise = self - .tx_router - .get_mut(&http_request_id) - .ok_or(SessionError::ChannelClosed(Some(http_request_id)))?; - let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity); - let (tx, rx) = channel; - request_wise.tx.tx = tx; - let index = last_event_id.index; - // sync messages after index - request_wise.tx.sync(index).await?; - Ok(StreamableHttpMessageReceiver { - http_request_id: Some(http_request_id), - inner: rx, - }) - } - None => { - let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity); - let (tx, rx) = channel; - self.common.tx = tx; - let index = last_event_id.index; - // sync messages after index - self.common.sync(index).await?; - Ok(StreamableHttpMessageReceiver { - http_request_id: None, - inner: rx, - }) - } - } - } -} - -enum SessionEvent { - ClientMessage { - message: ClientJsonRpcMessage, - http_request_id: Option, - }, - EstablishRequestWiseChannel { - responder: oneshot::Sender>, - }, - CloseRequestWiseChannel { - id: HttpRequestId, - responder: oneshot::Sender>, - }, - Resume { - last_event_id: EventId, - responder: oneshot::Sender>, - }, - InitializeRequest { - request: ClientJsonRpcMessage, - responder: oneshot::Sender>, - }, - Close, -} - -#[derive(Debug, Clone)] -pub enum SessionQuitReason { - ServiceTerminated, - ClientTerminated, - ExpectInitializeRequest, - ExpectInitializeResponse, - Cancelled, -} - -#[derive(Debug, Clone)] -pub struct SessionHandle { - id: SessionId, - // after all event_tx drop, inner task will be terminated - event_tx: Sender, -} - -impl SessionHandle { - /// Get the session id - pub fn id(&self) -> &SessionId { - &self.id - } - - /// Close the session - pub async fn close(&self) -> Result<(), SessionError> { - self.event_tx - 
.send(SessionEvent::Close) - .await - .map_err(|_| SessionError::SessionServiceTerminated)?; - Ok(()) - } - - /// Send a message to the session - pub async fn push_message( +pub trait SessionManager: Send + Sync + 'static { + type Error: std::error::Error + Send + 'static; + type Transport: crate::transport::Transport; + /// Create a new session with the given id and configuration. + fn create_session( &self, - message: ClientJsonRpcMessage, - http_request_id: Option, - ) -> Result<(), SessionError> { - self.event_tx - .send(SessionEvent::ClientMessage { - message, - http_request_id, - }) - .await - .map_err(|_| SessionError::SessionServiceTerminated)?; - Ok(()) - } - - /// establish a channel for a http-request, the corresponded message from server will be - /// sent through this channel. The channel will be closed when the request is completed, - /// or you can close it manually by calling [`SessionHandle::close_request_wise_channel`]. - pub async fn establish_request_wise_channel( - &self, - ) -> Result { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.event_tx - .send(SessionEvent::EstablishRequestWiseChannel { responder: tx }) - .await - .map_err(|_| SessionError::SessionServiceTerminated)?; - rx.await - .map_err(|_| SessionError::SessionServiceTerminated)? - } - - /// close the http-request wise channel. - pub async fn close_request_wise_channel( + ) -> impl Future> + Send; + fn initialize_session( &self, - request_id: HttpRequestId, - ) -> Result<(), SessionError> { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.event_tx - .send(SessionEvent::CloseRequestWiseChannel { - id: request_id, - responder: tx, - }) - .await - .map_err(|_| SessionError::SessionServiceTerminated)?; - rx.await - .map_err(|_| SessionError::SessionServiceTerminated)? - } - - /// Establish a common channel for general purpose messages. - pub async fn establish_common_channel( + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> impl Future> + Send; + fn has_session(&self, id: &SessionId) + -> impl Future> + Send; + fn close_session(&self, id: &SessionId) + -> impl Future> + Send; + fn create_stream( &self, - ) -> Result { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.event_tx - .send(SessionEvent::Resume { - last_event_id: EventId { - http_request_id: None, - index: 0, - }, - responder: tx, - }) - .await - .map_err(|_| SessionError::SessionServiceTerminated)?; - rx.await - .map_err(|_| SessionError::SessionServiceTerminated)? - } - - /// Resume streaming response by the last event id. This is suitable for both request wise and common channel. - pub async fn resume( + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> impl Future< + Output = Result + Send + 'static, Self::Error>, + > + Send; + fn create_stantalone_stream( &self, - last_event_id: EventId, - ) -> Result { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.event_tx - .send(SessionEvent::Resume { - last_event_id, - responder: tx, - }) - .await - .map_err(|_| SessionError::SessionServiceTerminated)?; - rx.await - .map_err(|_| SessionError::SessionServiceTerminated)? - } - - /// Send an initialize request to the session. And wait for the initialized response. - /// - /// This is used to establish a session with the server. 
- pub async fn initialize( + id: &SessionId, + ) -> impl Future< + Output = Result + Send + 'static, Self::Error>, + > + Send; + fn resume( &self, - request: ClientJsonRpcMessage, - ) -> Result { - let (tx, rx) = tokio::sync::oneshot::channel(); - self.event_tx - .send(SessionEvent::InitializeRequest { - request, - responder: tx, - }) - .await - .map_err(|_| SessionError::SessionServiceTerminated)?; - rx.await - .map_err(|_| SessionError::SessionServiceTerminated)? - } -} - -pub type SessionTransport = WorkerTransport; - -impl Worker for SessionWorker { - type Error = SessionError; - type Role = RoleServer; - fn err_closed() -> Self::Error { - SessionError::TransportClosed - } - fn err_join(e: tokio::task::JoinError) -> Self::Error { - SessionError::TokioJoinError(e) - } - fn config(&self) -> crate::transport::worker::WorkerConfig { - crate::transport::worker::WorkerConfig { - name: Some(format!("streamable-http-session-{}", self.id)), - channel_buffer_capacity: self.session_config.channel_capacity, - } - } - #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))] - async fn run(mut self, mut context: WorkerContext) -> Result<(), WorkerQuitReason> { - enum InnerEvent { - FromHttpService(SessionEvent), - FromHandler(WorkerSendRequest), - } - // waiting for initialize request - let evt = self.event_rx.recv().await.ok_or_else(|| { - WorkerQuitReason::fatal("transport terminated", "get initialize request") - })?; - let SessionEvent::InitializeRequest { request, responder } = evt else { - return Err(WorkerQuitReason::fatal( - "unexpected message", - "get initialize request", - )); - }; - context.send_to_handler(request).await?; - let send_initialize_response = context.recv_from_handler().await?; - responder - .send(Ok(send_initialize_response.message)) - .map_err(|_| { - WorkerQuitReason::fatal( - "failed to send initialize response to http service", - "send initialize response", - ) - })?; - send_initialize_response - .responder - .send(Ok(())) - .map_err(|_| WorkerQuitReason::HandlerTerminated)?; - let ct = context.cancellation_token.clone(); - loop { - let event = tokio::select! { - event = self.event_rx.recv() => { - if let Some(event) = event { - InnerEvent::FromHttpService(event) - } else { - return Err(WorkerQuitReason::fatal("session dropped", "waiting next session event")) - } - }, - from_handler = context.recv_from_handler() => { - InnerEvent::FromHandler(from_handler?) 
- } - _ = ct.cancelled() => { - return Err(WorkerQuitReason::Cancelled) - } - }; - match event { - InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => { - // catch response - match &message { - crate::model::JsonRpcMessage::Response(json_rpc_response) => { - let request_id = json_rpc_response.id.clone(); - self.unregister_resource(&ResourceKey::McpRequestId(request_id)); - } - crate::model::JsonRpcMessage::Error(json_rpc_error) => { - let request_id = json_rpc_error.id.clone(); - self.unregister_resource(&ResourceKey::McpRequestId(request_id)); - } - // unlikely happen - crate::model::JsonRpcMessage::BatchResponse( - json_rpc_batch_response_items, - ) => { - for item in json_rpc_batch_response_items { - let request_id = match item { - crate::model::JsonRpcBatchResponseItem::Response( - json_rpc_response, - ) => json_rpc_response.id.clone(), - crate::model::JsonRpcBatchResponseItem::Error( - json_rpc_error, - ) => json_rpc_error.id.clone(), - }; - self.unregister_resource(&ResourceKey::McpRequestId(request_id)); - } - } - _ => { - // no need to unregister resource - } - } - let handle_result = self.handle_server_message(message).await; - let _ = responder.send(handle_result); - } - InnerEvent::FromHttpService(SessionEvent::ClientMessage { - message: json_rpc_message, - http_request_id, - }) => { - match &json_rpc_message { - crate::model::JsonRpcMessage::Request(request) => { - if let Some(http_request_id) = http_request_id { - self.register_request(request, http_request_id) - } - } - crate::model::JsonRpcMessage::Notification(notification) => { - self.catch_cancellation_notification(notification) - } - crate::model::JsonRpcMessage::BatchRequest(items) => { - for r in items { - match r { - crate::model::JsonRpcBatchRequestItem::Request(request) => { - if let Some(http_request_id) = http_request_id { - self.register_request(request, http_request_id) - } - } - crate::model::JsonRpcBatchRequestItem::Notification( - notification, - ) => self.catch_cancellation_notification(notification), - } - } - } - _ => {} - } - context.send_to_handler(json_rpc_message).await?; - } - InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel { - responder, - }) => { - let handle_result = self.establish_request_wise_channel().await; - let _ = responder.send(handle_result); - } - InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel { - id, - responder, - }) => { - let _handle_result = self.tx_router.remove(&id); - let _ = responder.send(Ok(())); - } - InnerEvent::FromHttpService(SessionEvent::Resume { - last_event_id, - responder, - }) => { - let handle_result = self.resume(last_event_id).await; - let _ = responder.send(handle_result); - } - InnerEvent::FromHttpService(SessionEvent::Close) => { - return Err(WorkerQuitReason::TransportClosed); - } - _ => { - // ignore - } - } - } - } -} - -#[derive(Debug, Clone)] -pub struct SessionConfig { - channel_capacity: usize, -} - -impl SessionConfig { - pub const DEFAULT_CHANNEL_CAPACITY: usize = 16; -} - -impl Default for SessionConfig { - fn default() -> Self { - Self { - channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY, - } - } -} - -/// Create a new session with the given id and configuration. -/// -/// This function will return a pair of [`SessionHandle`] and [`SessionWorker`]. -/// -/// You can run the [`SessionWorker`] as a transport for mcp server. And use the [`SessionHandle`] operate the session. 
-pub fn create_session( - id: impl Into, - config: SessionConfig, -) -> (SessionHandle, SessionWorker) { - let id = id.into(); - let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity); - let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity); - let common = CachedTx::new_common(common_tx); - tracing::info!(session_id = ?id, "create new session"); - let handle = SessionHandle { - event_tx, - id: id.clone(), - }; - let session_worker = SessionWorker { - next_http_request_id: 0, - id, - tx_router: HashMap::new(), - resource_router: HashMap::new(), - common, - event_rx, - session_config: config.clone(), - }; - (handle, session_worker) + id: &SessionId, + last_event_id: String, + ) -> impl Future< + Output = Result + Send + 'static, Self::Error>, + > + Send; } diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs new file mode 100644 index 00000000..468e12bd --- /dev/null +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -0,0 +1,880 @@ +use std::{ + collections::{HashMap, HashSet, VecDeque}, + num::ParseIntError, + sync::Arc, + time::Duration, +}; + +use futures::Stream; +use thiserror::Error; +use tokio::sync::{ + mpsc::{Receiver, Sender}, + oneshot, +}; +use tokio_stream::wrappers::ReceiverStream; +use tracing::instrument; + +use crate::{ + RoleServer, + model::{ + CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest, + JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam, + ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification, + }, + transport::{ + WorkerTransport, + common::sever_side_http::{SessionId, session_id}, + worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest}, + }, +}; + +#[derive(Debug, Default)] +pub struct LocalSessionManager { + pub sessions: tokio::sync::RwLock>, + pub session_config: SessionConfig, +} + +#[derive(Debug, Error)] +pub enum LocalSessionManagerError { + #[error("Session not found: {0}")] + SessionNotFound(SessionId), + #[error("Session error: {0}")] + SessionError(#[from] SessionError), + #[error("Invalid event id: {0}")] + InvalidEventId(#[from] EventIdParseError), +} +impl SessionManager for LocalSessionManager { + type Error = LocalSessionManagerError; + type Transport = WorkerTransport; + async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> { + let id = session_id(); + let (handle, worker) = create_local_session(id.clone(), self.session_config.clone()); + self.sessions.write().await.insert(id.clone(), handle); + Ok((id, WorkerTransport::spawn(worker))) + } + async fn initialize_session( + &self, + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> Result { + let sessions = self.sessions.read().await; + let handle = sessions + .get(id) + .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?; + let response = handle.initialize(message).await?; + Ok(response) + } + async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> { + let mut sessions = self.sessions.write().await; + if let Some(handle) = sessions.remove(id) { + handle.close().await?; + } + Ok(()) + } + async fn has_session(&self, id: &SessionId) -> Result { + let sessions = self.sessions.read().await; + Ok(sessions.contains_key(id)) + } + async fn create_stream( + &self, + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> Result + Send + 'static, Self::Error> { + let sessions = 
self.sessions.read().await; + let handle = sessions + .get(id) + .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?; + let receiver = handle.establish_request_wise_channel().await?; + handle + .push_message(message, receiver.http_request_id) + .await?; + Ok(ReceiverStream::new(receiver.inner)) + } + + async fn create_stantalone_stream( + &self, + id: &SessionId, + ) -> Result + Send + 'static, Self::Error> { + let sessions = self.sessions.read().await; + let handle = sessions + .get(id) + .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?; + let receiver = handle.establish_common_channel().await?; + Ok(ReceiverStream::new(receiver.inner)) + } + + async fn resume( + &self, + id: &SessionId, + last_event_id: String, + ) -> Result + Send + 'static, Self::Error> { + let sessions = self.sessions.read().await; + let handle = sessions + .get(id) + .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?; + let receiver = handle.resume(last_event_id.parse()?).await?; + Ok(ReceiverStream::new(receiver.inner)) + } +} + +/// `/request_id>` +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct EventId { + http_request_id: Option, + index: usize, +} + +impl std::fmt::Display for EventId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.index)?; + match &self.http_request_id { + Some(http_request_id) => write!(f, "/{http_request_id}"), + None => write!(f, ""), + } + } +} + +#[derive(Debug, Clone, Error)] +pub enum EventIdParseError { + #[error("Invalid index: {0}")] + InvalidIndex(ParseIntError), + #[error("Invalid numeric request id: {0}")] + InvalidNumericRequestId(ParseIntError), + #[error("Missing request id type")] + InvalidRequestIdType, + #[error("Missing request id")] + MissingRequestId, +} + +impl std::str::FromStr for EventId { + type Err = EventIdParseError; + fn from_str(s: &str) -> Result { + if let Some((index, request_id)) = s.split_once("/") { + let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?; + let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?; + Ok(EventId { + http_request_id: Some(request_id), + index, + }) + } else { + let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?; + Ok(EventId { + http_request_id: None, + index, + }) + } + } +} + +use super::{ServerSseMessage, SessionManager}; + +struct CachedTx { + tx: Sender, + cache: VecDeque, + http_request_id: Option, + capacity: usize, +} + +impl CachedTx { + fn new(tx: Sender, http_request_id: Option) -> Self { + Self { + cache: VecDeque::with_capacity(tx.capacity()), + capacity: tx.capacity(), + tx, + http_request_id, + } + } + fn new_common(tx: Sender) -> Self { + Self::new(tx, None) + } + + async fn send(&mut self, message: ServerJsonRpcMessage) { + let index = self.cache.back().map_or(0, |m| { + m.event_id + .as_deref() + .unwrap_or_default() + .parse::() + .expect("valid event id") + .index + + 1 + }); + let event_id = EventId { + http_request_id: self.http_request_id, + index, + }; + let message = ServerSseMessage { + event_id: Some(event_id.to_string()), + message: Arc::new(message), + }; + if self.cache.len() >= self.capacity { + self.cache.pop_front(); + self.cache.push_back(message.clone()); + } else { + self.cache.push_back(message.clone()); + } + let _ = self.tx.send(message).await.inspect_err(|e| { + let event_id = &e.0.event_id; + tracing::trace!(?event_id, "trying to send message in a closed session") + }); + } + + async fn sync(&mut self, index: usize) -> 
Result<(), SessionError> { + let Some(front) = self.cache.front() else { + return Ok(()); + }; + let front_event_id = front + .event_id + .as_deref() + .unwrap_or_default() + .parse::()?; + let sync_index = index.saturating_sub(front_event_id.index); + if sync_index > self.cache.len() { + // invalid index + return Err(SessionError::InvalidEventId); + } + for message in self.cache.iter().skip(sync_index) { + let send_result = self.tx.send(message.clone()).await; + if send_result.is_err() { + let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?; + return Err(SessionError::ChannelClosed(Some(event_id.index as u64))); + } + } + Ok(()) + } +} + +struct HttpRequestWise { + resources: HashSet, + tx: CachedTx, +} + +type HttpRequestId = u64; +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +enum ResourceKey { + McpRequestId(RequestId), + ProgressToken(ProgressToken), +} + +pub struct LocalSessionWorker { + id: SessionId, + next_http_request_id: HttpRequestId, + tx_router: HashMap, + resource_router: HashMap, + common: CachedTx, + event_rx: Receiver, + session_config: SessionConfig, +} + +impl LocalSessionWorker { + pub fn id(&self) -> &SessionId { + &self.id + } +} + +#[derive(Debug, Error)] +pub enum SessionError { + #[error("Invalid request id: {0}")] + DuplicatedRequestId(HttpRequestId), + #[error("Channel closed: {0:?}")] + ChannelClosed(Option), + #[error("Cannot parse event id: {0}")] + EventIdParseError(#[from] EventIdParseError), + #[error("Session service terminated")] + SessionServiceTerminated, + #[error("Invalid event id")] + InvalidEventId, + #[error("Transport closed")] + TransportClosed, + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Tokio join error {0}")] + TokioJoinError(#[from] tokio::task::JoinError), +} + +impl From for std::io::Error { + fn from(value: SessionError) -> Self { + match value { + SessionError::Io(io) => io, + _ => std::io::Error::new(std::io::ErrorKind::Other, format!("Session error: {value}")), + } + } +} + +enum OutboundChannel { + RequestWise { id: HttpRequestId, close: bool }, + Common, +} + +pub struct StreamableHttpMessageReceiver { + pub http_request_id: Option, + pub inner: Receiver, +} + +impl LocalSessionWorker { + fn unregister_resource(&mut self, resource: &ResourceKey) { + if let Some(http_request_id) = self.resource_router.remove(resource) { + tracing::trace!(?resource, http_request_id, "unregister resource"); + if let Some(channel) = self.tx_router.get_mut(&http_request_id) { + channel.resources.remove(resource); + if channel.resources.is_empty() { + tracing::debug!(http_request_id, "close http request wise channel"); + self.tx_router.remove(&http_request_id); + } + } + } + } + fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) { + tracing::trace!(?resource, http_request_id, "register resource"); + if let Some(channel) = self.tx_router.get_mut(&http_request_id) { + channel.resources.insert(resource.clone()); + self.resource_router.insert(resource, http_request_id); + } + } + fn register_request( + &mut self, + request: &JsonRpcRequest, + http_request_id: HttpRequestId, + ) { + use crate::model::GetMeta; + self.register_resource( + ResourceKey::McpRequestId(request.id.clone()), + http_request_id, + ); + if let Some(progress_token) = request.request.get_meta().get_progress_token() { + self.register_resource( + ResourceKey::ProgressToken(progress_token.clone()), + http_request_id, + ); + } + } + fn catch_cancellation_notification( + &mut self, + notification: 
&JsonRpcNotification, + ) { + if let ClientNotification::CancelledNotification(n) = ¬ification.notification { + let request_id = n.params.request_id.clone(); + let resource = ResourceKey::McpRequestId(request_id); + self.unregister_resource(&resource); + } + } + fn next_http_request_id(&mut self) -> HttpRequestId { + let id = self.next_http_request_id; + self.next_http_request_id = self.next_http_request_id.wrapping_add(1); + id + } + async fn establish_request_wise_channel( + &mut self, + ) -> Result { + let http_request_id = self.next_http_request_id(); + let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity); + self.tx_router.insert( + http_request_id, + HttpRequestWise { + resources: Default::default(), + tx: CachedTx::new(tx, Some(http_request_id)), + }, + ); + tracing::debug!(http_request_id, "establish new request wise channel"); + Ok(StreamableHttpMessageReceiver { + http_request_id: Some(http_request_id), + inner: rx, + }) + } + fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel { + match &message { + ServerJsonRpcMessage::Request(_) => OutboundChannel::Common, + ServerJsonRpcMessage::Notification(JsonRpcNotification { + notification: + ServerNotification::ProgressNotification(Notification { + params: ProgressNotificationParam { progress_token, .. }, + .. + }), + .. + }) => { + let id = self + .resource_router + .get(&ResourceKey::ProgressToken(progress_token.clone())); + + if let Some(id) = id { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::Notification(JsonRpcNotification { + notification: + ServerNotification::CancelledNotification(Notification { + params: CancelledNotificationParam { request_id, .. }, + .. + }), + .. 
+ }) => { + if let Some(id) = self + .resource_router + .get(&ResourceKey::McpRequestId(request_id.clone())) + { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common, + ServerJsonRpcMessage::Response(json_rpc_response) => { + if let Some(id) = self + .resource_router + .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone())) + { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::Error(json_rpc_error) => { + if let Some(id) = self + .resource_router + .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone())) + { + OutboundChannel::RequestWise { + id: *id, + close: false, + } + } else { + OutboundChannel::Common + } + } + ServerJsonRpcMessage::BatchRequest(_) | ServerJsonRpcMessage::BatchResponse(_) => { + // the server side should never yield a batch request or response now + unreachable!("server side won't yield batch request or response") + } + } + } + async fn handle_server_message( + &mut self, + message: ServerJsonRpcMessage, + ) -> Result<(), SessionError> { + let outbound_channel = self.resolve_outbound_channel(&message); + match outbound_channel { + OutboundChannel::RequestWise { id, close } => { + if let Some(request_wise) = self.tx_router.get_mut(&id) { + request_wise.tx.send(message).await; + if close { + self.tx_router.remove(&id); + } + } else { + return Err(SessionError::ChannelClosed(Some(id))); + } + } + OutboundChannel::Common => self.common.send(message).await, + } + Ok(()) + } + async fn resume( + &mut self, + last_event_id: EventId, + ) -> Result { + match last_event_id.http_request_id { + Some(http_request_id) => { + let request_wise = self + .tx_router + .get_mut(&http_request_id) + .ok_or(SessionError::ChannelClosed(Some(http_request_id)))?; + let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity); + let (tx, rx) = channel; + request_wise.tx.tx = tx; + let index = last_event_id.index; + // sync messages after index + request_wise.tx.sync(index).await?; + Ok(StreamableHttpMessageReceiver { + http_request_id: Some(http_request_id), + inner: rx, + }) + } + None => { + let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity); + let (tx, rx) = channel; + self.common.tx = tx; + let index = last_event_id.index; + // sync messages after index + self.common.sync(index).await?; + Ok(StreamableHttpMessageReceiver { + http_request_id: None, + inner: rx, + }) + } + } + } +} + +enum SessionEvent { + ClientMessage { + message: ClientJsonRpcMessage, + http_request_id: Option, + }, + EstablishRequestWiseChannel { + responder: oneshot::Sender>, + }, + CloseRequestWiseChannel { + id: HttpRequestId, + responder: oneshot::Sender>, + }, + Resume { + last_event_id: EventId, + responder: oneshot::Sender>, + }, + InitializeRequest { + request: ClientJsonRpcMessage, + responder: oneshot::Sender>, + }, + Close, +} + +#[derive(Debug, Clone)] +pub enum SessionQuitReason { + ServiceTerminated, + ClientTerminated, + ExpectInitializeRequest, + ExpectInitializeResponse, + Cancelled, +} + +#[derive(Debug, Clone)] +pub struct LocalSessionHandle { + id: SessionId, + // after all event_tx drop, inner task will be terminated + event_tx: Sender, +} + +impl LocalSessionHandle { + /// Get the session id + pub fn id(&self) -> &SessionId { + &self.id + } + + /// Close the session + pub async fn close(&self) -> Result<(), SessionError> { + self.event_tx + 
.send(SessionEvent::Close) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + Ok(()) + } + + /// Send a message to the session + pub async fn push_message( + &self, + message: ClientJsonRpcMessage, + http_request_id: Option, + ) -> Result<(), SessionError> { + self.event_tx + .send(SessionEvent::ClientMessage { + message, + http_request_id, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + Ok(()) + } + + /// establish a channel for a http-request, the corresponded message from server will be + /// sent through this channel. The channel will be closed when the request is completed, + /// or you can close it manually by calling [`LocalSessionHandle::close_request_wise_channel`]. + pub async fn establish_request_wise_channel( + &self, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::EstablishRequestWiseChannel { responder: tx }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + /// close the http-request wise channel. + pub async fn close_request_wise_channel( + &self, + request_id: HttpRequestId, + ) -> Result<(), SessionError> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::CloseRequestWiseChannel { + id: request_id, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + /// Establish a common channel for general purpose messages. + pub async fn establish_common_channel( + &self, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::Resume { + last_event_id: EventId { + http_request_id: None, + index: 0, + }, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + /// Resume streaming response by the last event id. This is suitable for both request wise and common channel. + pub async fn resume( + &self, + last_event_id: EventId, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::Resume { + last_event_id, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + /// Send an initialize request to the session. And wait for the initialized response. + /// + /// This is used to establish a session with the server. + pub async fn initialize( + &self, + request: ClientJsonRpcMessage, + ) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::InitializeRequest { + request, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? 
+    }
+}
+
+pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
+
+impl Worker for LocalSessionWorker {
+    type Error = SessionError;
+    type Role = RoleServer;
+    fn err_closed() -> Self::Error {
+        SessionError::TransportClosed
+    }
+    fn err_join(e: tokio::task::JoinError) -> Self::Error {
+        SessionError::TokioJoinError(e)
+    }
+    fn config(&self) -> crate::transport::worker::WorkerConfig {
+        crate::transport::worker::WorkerConfig {
+            name: Some(format!("streamable-http-session-{}", self.id)),
+            channel_buffer_capacity: self.session_config.channel_capacity,
+        }
+    }
+    #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))]
+    async fn run(mut self, mut context: WorkerContext<Self>) -> Result<(), WorkerQuitReason> {
+        enum InnerEvent {
+            FromHttpService(SessionEvent),
+            FromHandler(WorkerSendRequest<LocalSessionWorker>),
+        }
+        // waiting for initialize request
+        let evt = self.event_rx.recv().await.ok_or_else(|| {
+            WorkerQuitReason::fatal("transport terminated", "get initialize request")
+        })?;
+        let SessionEvent::InitializeRequest { request, responder } = evt else {
+            return Err(WorkerQuitReason::fatal(
+                "unexpected message",
+                "get initialize request",
+            ));
+        };
+        context.send_to_handler(request).await?;
+        let send_initialize_response = context.recv_from_handler().await?;
+        responder
+            .send(Ok(send_initialize_response.message))
+            .map_err(|_| {
+                WorkerQuitReason::fatal(
+                    "failed to send initialize response to http service",
+                    "send initialize response",
+                )
+            })?;
+        send_initialize_response
+            .responder
+            .send(Ok(()))
+            .map_err(|_| WorkerQuitReason::HandlerTerminated)?;
+        let ct = context.cancellation_token.clone();
+        let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
+        loop {
+            let keep_alive_timeout = tokio::time::sleep(keep_alive);
+            let event = tokio::select! {
+                event = self.event_rx.recv() => {
+                    if let Some(event) = event {
+                        InnerEvent::FromHttpService(event)
+                    } else {
+                        return Err(WorkerQuitReason::fatal("session dropped", "waiting next session event"))
+                    }
+                },
+                from_handler = context.recv_from_handler() => {
+                    InnerEvent::FromHandler(from_handler?)
+ } + _ = ct.cancelled() => { + return Err(WorkerQuitReason::Cancelled) + } + _ = keep_alive_timeout => { + return Err(WorkerQuitReason::fatal("keep live timeout", "poll next session event")) + } + }; + match event { + InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => { + // catch response + let to_unregister = match &message { + crate::model::JsonRpcMessage::Response(json_rpc_response) => { + let request_id = json_rpc_response.id.clone(); + Some(ResourceKey::McpRequestId(request_id)) + } + crate::model::JsonRpcMessage::Error(json_rpc_error) => { + let request_id = json_rpc_error.id.clone(); + Some(ResourceKey::McpRequestId(request_id)) + } + _ => { + None + // no need to unregister resource + } + }; + let handle_result = self.handle_server_message(message).await; + let _ = responder.send(handle_result).inspect_err(|error| { + tracing::warn!(?error, "failed to send message to http service handler"); + }); + if let Some(to_unregister) = to_unregister { + self.unregister_resource(&to_unregister); + } + } + InnerEvent::FromHttpService(SessionEvent::ClientMessage { + message: json_rpc_message, + http_request_id, + }) => { + match &json_rpc_message { + crate::model::JsonRpcMessage::Request(request) => { + if let Some(http_request_id) = http_request_id { + self.register_request(request, http_request_id) + } + } + crate::model::JsonRpcMessage::Notification(notification) => { + self.catch_cancellation_notification(notification) + } + crate::model::JsonRpcMessage::BatchRequest(items) => { + for r in items { + match r { + crate::model::JsonRpcBatchRequestItem::Request(request) => { + if let Some(http_request_id) = http_request_id { + self.register_request(request, http_request_id) + } + } + crate::model::JsonRpcBatchRequestItem::Notification( + notification, + ) => self.catch_cancellation_notification(notification), + } + } + } + _ => {} + } + context.send_to_handler(json_rpc_message).await?; + } + InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel { + responder, + }) => { + let handle_result = self.establish_request_wise_channel().await; + let _ = responder.send(handle_result); + } + InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel { + id, + responder, + }) => { + let _handle_result = self.tx_router.remove(&id); + let _ = responder.send(Ok(())); + } + InnerEvent::FromHttpService(SessionEvent::Resume { + last_event_id, + responder, + }) => { + let handle_result = self.resume(last_event_id).await; + let _ = responder.send(handle_result); + } + InnerEvent::FromHttpService(SessionEvent::Close) => { + return Err(WorkerQuitReason::TransportClosed); + } + _ => { + // ignore + } + } + } + } +} + +#[derive(Debug, Clone)] +pub struct SessionConfig { + /// the capacity of the channel for the session. Default is 16. + pub channel_capacity: usize, + /// if set, the session will be closed after this duration of inactivity. + pub keep_alive: Option, +} + +impl SessionConfig { + pub const DEFAULT_CHANNEL_CAPACITY: usize = 16; +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY, + keep_alive: None, + } + } +} + +/// Create a new session with the given id and configuration. +/// +/// This function will return a pair of [`LocalSessionHandle`] and [`LocalSessionWorker`]. +/// +/// You can run the [`LocalSessionWorker`] as a transport for mcp server. And use the [`LocalSessionHandle`] operate the session. 
+pub fn create_local_session( + id: impl Into, + config: SessionConfig, +) -> (LocalSessionHandle, LocalSessionWorker) { + let id = id.into(); + let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity); + let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity); + let common = CachedTx::new_common(common_tx); + tracing::info!(session_id = ?id, "create new session"); + let handle = LocalSessionHandle { + event_tx, + id: id.clone(), + }; + let session_worker = LocalSessionWorker { + next_http_request_id: 0, + id, + tx_router: HashMap::new(), + resource_router: HashMap::new(), + common, + event_rx, + session_config: config.clone(), + }; + (handle, session_worker) +} diff --git a/crates/rmcp/src/transport/streamable_http_server/session/never.rs b/crates/rmcp/src/transport/streamable_http_server/session/never.rs new file mode 100644 index 00000000..a7e84a58 --- /dev/null +++ b/crates/rmcp/src/transport/streamable_http_server/session/never.rs @@ -0,0 +1,100 @@ +use futures::Stream; +use thiserror::Error; + +use super::{ServerSseMessage, SessionId, SessionManager}; +use crate::{ + RoleServer, + model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, + transport::Transport, +}; + +#[derive(Debug, Clone, Error)] +#[error("Session management is not supported")] +pub struct ErrorSessionManagementNotSupported; +#[derive(Debug, Clone, Default)] +pub struct NeverSessionManager {} +pub enum NeverTransport {} +impl Transport for NeverTransport { + type Error = ErrorSessionManagementNotSupported; + + fn send( + &mut self, + _item: ServerJsonRpcMessage, + ) -> impl Future> + Send + 'static { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } + + fn receive(&mut self) -> impl Future> { + futures::future::ready(None) + } + + async fn close(&mut self) -> Result<(), Self::Error> { + Err(ErrorSessionManagementNotSupported) + } +} + +impl SessionManager for NeverSessionManager { + type Error = ErrorSessionManagementNotSupported; + type Transport = NeverTransport; + + fn create_session( + &self, + ) -> impl Future> + Send { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } + + fn initialize_session( + &self, + _id: &SessionId, + _message: ClientJsonRpcMessage, + ) -> impl Future> + Send { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } + + fn has_session( + &self, + _id: &SessionId, + ) -> impl Future> + Send { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } + + fn close_session( + &self, + _id: &SessionId, + ) -> impl Future> + Send { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } + + fn create_stream( + &self, + _id: &SessionId, + _message: ClientJsonRpcMessage, + ) -> impl Future< + Output = Result + Send + 'static, Self::Error>, + > + Send { + futures::future::ready(Result::, _>::Err( + ErrorSessionManagementNotSupported, + )) + } + fn create_stantalone_stream( + &self, + _id: &SessionId, + ) -> impl Future< + Output = Result + Send + 'static, Self::Error>, + > + Send { + futures::future::ready(Result::, _>::Err( + ErrorSessionManagementNotSupported, + )) + } + fn resume( + &self, + _id: &SessionId, + _last_event_id: String, + ) -> impl Future< + Output = Result + Send + 'static, Self::Error>, + > + Send { + futures::future::ready(Result::, _>::Err( + ErrorSessionManagementNotSupported, + )) + } +} diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs new file mode 100644 index 00000000..c6ea1ef2 
--- /dev/null +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -0,0 +1,406 @@ +use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; + +use bytes::Bytes; +use futures::{StreamExt, future::BoxFuture}; +use http::{Method, Request, Response, header::ALLOW}; +use http_body::Body; +use http_body_util::{BodyExt, Full, combinators::UnsyncBoxBody}; +use tokio_stream::wrappers::ReceiverStream; + +use super::session::SessionManager; +use crate::{ + RoleServer, + model::ClientJsonRpcMessage, + serve_server, + service::serve_directly, + transport::{ + OneshotTransport, TransportAdapterIdentity, + common::{ + http_header::{ + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, + }, + sever_side_http::{ + BoxResponse, ServerSseMessage, accepted_response, expect_json, + internal_error_response, sse_stream_response, + }, + }, + }, +}; + +#[derive(Debug, Clone)] +pub struct StreamableHttpServerConfig { + /// The ping message duration for SSE connections. + pub sse_keep_alive: Option, + /// If true, the server will create a session for each request and keep it alive. + pub stateful_mode: bool, +} + +impl Default for StreamableHttpServerConfig { + fn default() -> Self { + Self { + sse_keep_alive: Some(Duration::from_secs(15)), + stateful_mode: true, + } + } +} + +pub struct StreamableHttpService { + pub config: StreamableHttpServerConfig, + session_manager: Arc, + service_factory: Arc S + Send + Sync>, +} + +impl Clone for StreamableHttpService { + fn clone(&self) -> Self { + Self { + config: self.config.clone(), + session_manager: self.session_manager.clone(), + service_factory: self.service_factory.clone(), + } + } +} + +impl tower_service::Service> for StreamableHttpService +where + RequestBody: Body + Send + 'static, + S: crate::Service, + M: SessionManager, + RequestBody::Error: Display, + RequestBody::Data: Send + 'static, +{ + type Response = BoxResponse; + type Error = Infallible; + type Future = BoxFuture<'static, Result>; + fn call(&mut self, req: http::Request) -> Self::Future { + let service = self.clone(); + Box::pin(async move { + let response = service.handle(req).await; + Ok(response) + }) + } + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } +} + +impl StreamableHttpService +where + S: crate::Service + Send + 'static, + M: SessionManager, +{ + pub fn new( + service_factory: impl Fn() -> S + Send + Sync + 'static, + session_manager: Arc, + config: StreamableHttpServerConfig, + ) -> Self { + Self { + config, + session_manager, + service_factory: Arc::new(service_factory), + } + } + fn get_service(&self) -> S { + (self.service_factory)() + } + pub async fn handle(&self, request: Request) -> Response> + where + B: Body + Send + 'static, + B::Error: Display, + { + let method = request.method().clone(); + let result = match method { + Method::GET => self.handle_get(request).await, + Method::POST => self.handle_post(request).await, + Method::DELETE => self.handle_delete(request).await, + _ => { + // Handle other methods or return an error + let response = Response::builder() + .status(http::StatusCode::METHOD_NOT_ALLOWED) + .header(ALLOW, "GET, POST, DELETE") + .body(Full::new(Bytes::from("Method Not Allowed")).boxed_unsync()) + .expect("valid response"); + return response; + } + }; + match result { + Ok(response) => response, + Err(response) => response, + } + } + async fn handle_get(&self, request: Request) -> Result + where + B: Body + Send + 'static, + 
B::Error: Display, + { + // check accept header + if !request + .headers() + .get(http::header::ACCEPT) + .and_then(|header| header.to_str().ok()) + .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE)) + { + return Ok(Response::builder() + .status(http::StatusCode::NOT_ACCEPTABLE) + .body( + Full::new(Bytes::from( + "Not Acceptable: Client must accept text/event-stream", + )) + .boxed_unsync(), + ) + .expect("valid response")); + } + // check session id + let session_id = request + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned().into()); + let Some(session_id) = session_id else { + // unauthorized + return Ok(Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed_unsync()) + .expect("valid response")); + }; + // check if session exists + let has_session = self + .session_manager + .has_session(&session_id) + .await + .map_err(internal_error_response("check session"))?; + if !has_session { + // unauthorized + return Ok(Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed_unsync()) + .expect("valid response")); + } + // check if last event id is provided + let last_event_id = request + .headers() + .get(HEADER_LAST_EVENT_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned()); + if let Some(last_event_id) = last_event_id { + // check if session has this event id + let stream = self + .session_manager + .resume(&session_id, last_event_id) + .await + .map_err(internal_error_response("resume session"))?; + Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + } else { + // create standalone stream + let stream = self + .session_manager + .create_stantalone_stream(&session_id) + .await + .map_err(internal_error_response("create standalone stream"))?; + Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + } + } + + async fn handle_post(&self, request: Request) -> Result + where + B: Body + Send + 'static, + B::Error: Display, + { + // check accept header + if !request + .headers() + .get(http::header::ACCEPT) + .and_then(|header| header.to_str().ok()) + .is_some_and(|header| { + header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE) + }) + { + return Ok(Response::builder() + .status(http::StatusCode::NOT_ACCEPTABLE) + .body(Full::new(Bytes::from("Not Acceptable: Client must accept both application/json and text/event-stream")).boxed_unsync()) + .expect("valid response")); + } + + // check content type + if !request + .headers() + .get(http::header::CONTENT_TYPE) + .and_then(|header| header.to_str().ok()) + .is_some_and(|header| header.starts_with(JSON_MIME_TYPE)) + { + return Ok(Response::builder() + .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE) + .body( + Full::new(Bytes::from( + "Unsupported Media Type: Content-Type must be application/json", + )) + .boxed_unsync(), + ) + .expect("valid response")); + } + + // json deserialize request body + let (part, body) = request.into_parts(); + let message = match expect_json(body).await { + Ok(message) => message, + Err(response) => return Ok(response), + }; + + if self.config.stateful_mode { + // do we have a session id? 
+ let session_id = part + .headers + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()); + if let Some(session_id) = session_id { + let session_id = session_id.to_owned().into(); + let has_session = self + .session_manager + .has_session(&session_id) + .await + .map_err(internal_error_response("check session"))?; + if !has_session { + // unauthorized + return Ok(Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body( + Full::new(Bytes::from("Unauthorized: Session not found")) + .boxed_unsync(), + ) + .expect("valid response")); + } + let stream = self + .session_manager + .create_stream(&session_id, message) + .await + .map_err(internal_error_response("get session"))?; + Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + } else { + let (session_id, transport) = self + .session_manager + .create_session() + .await + .map_err(internal_error_response("create session"))?; + let service = self.get_service(); + // spawn a task to serve the session + tokio::spawn({ + let session_manager = self.session_manager.clone(); + let session_id = session_id.clone(); + async move { + let service = serve_server::( + service, transport, + ) + .await; + match service { + Ok(service) => { + // on service created + let _ = service.waiting().await; + } + Err(e) => { + tracing::error!("Failed to create service: {e}"); + } + } + let _ = session_manager + .close_session(&session_id) + .await + .inspect_err(|e| { + tracing::error!("Failed to close session {session_id}: {e}"); + }); + } + }); + // get initialize response + let response = self + .session_manager + .initialize_session(&session_id, message) + .await + .map_err(internal_error_response("create stream"))?; + let mut response = sse_stream_response( + futures::stream::once({ + async move { + ServerSseMessage { + event_id: None, + message: response.into(), + } + } + }), + self.config.sse_keep_alive, + ); + + response.headers_mut().insert( + HEADER_SESSION_ID, + session_id + .parse() + .map_err(internal_error_response("create session id header"))?, + ); + Ok(response) + } + } else { + let service = self.get_service(); + match message { + ClientJsonRpcMessage::Request(request) => { + let (transport, receiver) = + OneshotTransport::::new(ClientJsonRpcMessage::Request(request)); + let service = serve_directly(service, transport, None); + tokio::spawn(async move { + // on service created + let _ = service.waiting().await; + }); + Ok(sse_stream_response( + ReceiverStream::new(receiver).map(|message| { + tracing::info!(?message); + ServerSseMessage { + event_id: None, + message: message.into(), + } + }), + self.config.sse_keep_alive, + )) + } + ClientJsonRpcMessage::Notification(notification) => { + service + .handle_notification(notification.notification) + .await + .map_err(internal_error_response("handle notification"))?; + Ok(accepted_response()) + } + ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()), + ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()), + _ => Ok(Response::builder() + .status(http::StatusCode::NOT_IMPLEMENTED) + .body( + Full::new(Bytes::from("Batch requests are not supported yet")) + .boxed_unsync(), + ) + .expect("valid response")), + } + } + } + + async fn handle_delete(&self, request: Request) -> Result + where + B: Body + Send + 'static, + B::Error: Display, + { + // check session id + let session_id = request + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_owned().into()); + let Some(session_id) = session_id else { + // 
unauthorized + return Ok(Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed_unsync()) + .expect("valid response")); + }; + // close session + self.session_manager + .close_session(&session_id) + .await + .map_err(internal_error_response("close session"))?; + Ok(accepted_response()) + } +} diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 50f4293d..0272050d 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -2,10 +2,14 @@ use rmcp::{ ServiceExt, service::QuitReason, transport::{ - ConfigureCommandExt, SseServer, StreamableHttpClientTransport, TokioChildProcess, - streamable_http_server::axum::StreamableHttpServer, + ConfigureCommandExt, SseServer, StreamableHttpClientTransport, StreamableHttpServerConfig, + TokioChildProcess, + streamable_http_server::{ + session::local::LocalSessionManager, tower::StreamableHttpService, + }, }, }; +use tokio_util::sync::CancellationToken; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod common; use common::calculator::Calculator; @@ -90,10 +94,26 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { .wait() .await?; - let ct = StreamableHttpServer::serve(STREAMABLE_HTTP_BIND_ADDRESS.parse()?) - .await? - .with_service(Calculator::default); - + let service: StreamableHttpService = + StreamableHttpService::new( + Calculator::default, + Default::default(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + }, + ); + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind(STREAMABLE_HTTP_BIND_ADDRESS).await?; + let ct = CancellationToken::new(); + let handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); let exit_status = tokio::process::Command::new("node") .arg("tests/test_with_js/streamable_client.js") .spawn()? 
@@ -101,6 +121,7 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { .await?; assert!(exit_status.success()); ct.cancel(); + handle.await?; Ok(()) } diff --git a/crates/rmcp/tests/test_with_js/streamable_client.js b/crates/rmcp/tests/test_with_js/streamable_client.js index b22acba3..99826131 100644 --- a/crates/rmcp/tests/test_with_js/streamable_client.js +++ b/crates/rmcp/tests/test_with_js/streamable_client.js @@ -1,7 +1,7 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -const transport = new StreamableHTTPClientTransport(new URL(`http://127.0.0.1:8001/`)); +const transport = new StreamableHTTPClientTransport(new URL(`http://127.0.0.1:8001/mcp/`)); const client = new Client( { diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index fe371b00..776b5f9e 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -7,8 +7,20 @@ edition = "2024" publish = false [dependencies] -rmcp= { path = "../../crates/rmcp", features = ["server", "transport-sse-server", "transport-io", "transport-streamable-http-server", "auth"] } -tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "io-std", "signal"] } +rmcp = { path = "../../crates/rmcp", features = [ + "server", + "transport-sse-server", + "transport-io", + "transport-streamable-http-server", + "auth", +] } +tokio = { version = "1", features = [ + "macros", + "rt", + "rt-multi-thread", + "io-std", + "signal", +] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" anyhow = "1.0" @@ -26,9 +38,10 @@ reqwest = { version = "0.12", features = ["json"] } chrono = "0.4" uuid = { version = "1.6", features = ["v4", "serde"] } serde_urlencoded = "0.7" -askama = { version = "0.14"} +askama = { version = "0.14" } tower-http = { version = "0.6", features = ["cors"] } - +hyper = { version = "1" } +hyper-util = { version = "0", features = ["server"] } [dev-dependencies] tokio-stream = { version = "0.1" } @@ -60,4 +73,8 @@ path = "src/complex_auth_sse.rs" [[example]] name = "servers_simple_auth_sse" -path = "src/simple_auth_sse.rs" \ No newline at end of file +path = "src/simple_auth_sse.rs" + +[[example]] +name = "counter_hyper_streamable_http" +path = "src/counter_hyper_streamable_http.rs" diff --git a/examples/servers/README.md b/examples/servers/README.md index fc55506d..c80b130a 100644 --- a/examples/servers/README.md +++ b/examples/servers/README.md @@ -38,12 +38,19 @@ A minimal server example using stdio transport. ### Counter Streamable HTTP Server (`counter_streamhttp.rs`) -A server using streamable HTTP transport for MCP communication. +A server using streamable HTTP transport for MCP communication, with axum. - Runs on HTTP with streaming capabilities - Provides counter tools via HTTP streaming - Demonstrates streamable HTTP transport configuration +### Counter Streamable HTTP Server with Hyper (`counter_hyper_streamable_http.rs`) + +A server using streamable HTTP transport for MCP communication, with hyper. +- Runs on HTTP with streaming capabilities +- Provides counter tools via HTTP streaming +- Demonstrates streamable HTTP transport configuration + ### Complex OAuth SSE Server (`complex_auth_sse.rs`) A comprehensive example demonstrating OAuth 2.0 integration with MCP servers. 
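Note for reviewers: both example servers below run with the default `StreamableHttpServerConfig`, i.e. in stateful mode. The sketch here is not part of the patch; it only illustrates the stateless branch of the new tower service under a few assumptions: that `NeverSessionManager` is exported as `streamable_http_server::session::never::NeverSessionManager`, that `Counter` is the same example handler from `common::counter`, and that the bind address is arbitrary. With `stateful_mode: false`, no session id is issued and each POSTed request is answered over a one-shot transport, so the session manager is only a placeholder.

```rust
mod common;
use common::counter::Counter;
use rmcp::transport::{
    StreamableHttpServerConfig,
    streamable_http_server::{StreamableHttpService, session::never::NeverSessionManager},
};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Stateless mode: the placeholder NeverSessionManager is never asked to
    // create a session; every POST is served on a one-shot SSE response.
    let service = StreamableHttpService::new(
        Counter::new,
        NeverSessionManager::default().into(),
        StreamableHttpServerConfig {
            sse_keep_alive: None,
            stateful_mode: false,
        },
    );
    let router = axum::Router::new().nest_service("/mcp", service);
    let listener = tokio::net::TcpListener::bind("127.0.0.1:8000").await?;
    axum::serve(listener, router)
        .with_graceful_shutdown(async {
            let _ = tokio::signal::ctrl_c().await;
        })
        .await?;
    Ok(())
}
```
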
diff --git a/examples/servers/src/counter_hyper_streamable_http.rs b/examples/servers/src/counter_hyper_streamable_http.rs new file mode 100644 index 00000000..fefb3d47 --- /dev/null +++ b/examples/servers/src/counter_hyper_streamable_http.rs @@ -0,0 +1,36 @@ +mod common; +use common::counter::Counter; +use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn::auto::Builder, + service::TowerToHyperService, +}; +use rmcp::transport::streamable_http_server::{ + StreamableHttpService, session::local::LocalSessionManager, +}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let service = TowerToHyperService::new(StreamableHttpService::new( + Counter::new, + LocalSessionManager::default().into(), + Default::default(), + )); + // GET /hello/warp => 200 OK with body "Hello, warp!" + let listener = tokio::net::TcpListener::bind("[::1]:8080").await?; + loop { + let io = tokio::select! { + _ = tokio::signal::ctrl_c() => break, + accept = listener.accept() => { + TokioIo::new(accept?.0) + } + }; + let service = service.clone(); + tokio::spawn(async move { + let _result = Builder::new(TokioExecutor::default()) + .serve_connection(io, service) + .await; + }); + } + Ok(()) +} diff --git a/examples/servers/src/counter_streamhttp.rs b/examples/servers/src/counter_streamhttp.rs index e65be6b1..f4fa1d6c 100644 --- a/examples/servers/src/counter_streamhttp.rs +++ b/examples/servers/src/counter_streamhttp.rs @@ -1,4 +1,6 @@ -use rmcp::transport::streamable_http_server::axum::StreamableHttpServer; +use rmcp::transport::streamable_http_server::{ + StreamableHttpService, session::local::LocalSessionManager, +}; use tracing_subscriber::{ layer::SubscriberExt, util::SubscriberInitExt, @@ -19,11 +21,16 @@ async fn main() -> anyhow::Result<()> { .with(tracing_subscriber::fmt::layer()) .init(); - let ct = StreamableHttpServer::serve(BIND_ADDRESS.parse()?) - .await? 
- .with_service(Counter::new); + let service = StreamableHttpService::new( + Counter::new, + LocalSessionManager::default().into(), + Default::default(), + ); - tokio::signal::ctrl_c().await?; - ct.cancel(); + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?; + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async { tokio::signal::ctrl_c().await.unwrap() }) + .await; Ok(()) } From 6347b9862bedd3efd36f8002c175e494a3a06bd5 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 27 May 2025 14:17:33 +0800 Subject: [PATCH 2/4] fix: fix typo --- crates/rmcp/src/transport/common.rs | 2 +- .../common/{sever_side_http.rs => server_side_http.rs} | 0 crates/rmcp/src/transport/sse_server.rs | 2 +- crates/rmcp/src/transport/streamable_http_server/session.rs | 6 +++--- .../src/transport/streamable_http_server/session/local.rs | 4 ++-- .../src/transport/streamable_http_server/session/never.rs | 2 +- crates/rmcp/src/transport/streamable_http_server/tower.rs | 4 ++-- 7 files changed, 10 insertions(+), 10 deletions(-) rename crates/rmcp/src/transport/common/{sever_side_http.rs => server_side_http.rs} (100%) diff --git a/crates/rmcp/src/transport/common.rs b/crates/rmcp/src/transport/common.rs index 03c78266..401c6f2d 100644 --- a/crates/rmcp/src/transport/common.rs +++ b/crates/rmcp/src/transport/common.rs @@ -2,7 +2,7 @@ feature = "transport-streamable-http-server", feature = "transport-sse-server" ))] -pub mod sever_side_http; +pub mod server_side_http; pub mod http_header; diff --git a/crates/rmcp/src/transport/common/sever_side_http.rs b/crates/rmcp/src/transport/common/server_side_http.rs similarity index 100% rename from crates/rmcp/src/transport/common/sever_side_http.rs rename to crates/rmcp/src/transport/common/server_side_http.rs diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index 0fc12b2e..e1a2f6c8 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -19,7 +19,7 @@ use crate::{ RoleServer, Service, model::ClientJsonRpcMessage, service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct}, - transport::common::sever_side_http::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, + transport::common::server_side_http::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id}, }; type TxStore = diff --git a/crates/rmcp/src/transport/streamable_http_server/session.rs b/crates/rmcp/src/transport/streamable_http_server/session.rs index 3b775fed..f2997585 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session.rs @@ -1,10 +1,10 @@ use futures::Stream; -pub use crate::transport::common::sever_side_http::SessionId; +pub use crate::transport::common::server_side_http::SessionId; use crate::{ RoleServer, model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, - transport::common::sever_side_http::ServerSseMessage, + transport::common::server_side_http::ServerSseMessage, }; pub mod local; @@ -33,7 +33,7 @@ pub trait SessionManager: Send + Sync + 'static { ) -> impl Future< Output = Result + Send + 'static, Self::Error>, > + Send; - fn create_stantalone_stream( + fn create_standalone_stream( &self, id: &SessionId, ) -> impl Future< diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index 468e12bd..843a2a53 100644 --- 
a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -23,7 +23,7 @@ use crate::{ }, transport::{ WorkerTransport, - common::sever_side_http::{SessionId, session_id}, + common::server_side_http::{SessionId, session_id}, worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest}, }, }; @@ -91,7 +91,7 @@ impl SessionManager for LocalSessionManager { Ok(ReceiverStream::new(receiver.inner)) } - async fn create_stantalone_stream( + async fn create_standalone_stream( &self, id: &SessionId, ) -> Result + Send + 'static, Self::Error> { diff --git a/crates/rmcp/src/transport/streamable_http_server/session/never.rs b/crates/rmcp/src/transport/streamable_http_server/session/never.rs index a7e84a58..58abfaa5 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/never.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/never.rs @@ -76,7 +76,7 @@ impl SessionManager for NeverSessionManager { ErrorSessionManagementNotSupported, )) } - fn create_stantalone_stream( + fn create_standalone_stream( &self, _id: &SessionId, ) -> impl Future< diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index c6ea1ef2..966bc48d 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -19,7 +19,7 @@ use crate::{ http_header::{ EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, }, - sever_side_http::{ + server_side_http::{ BoxResponse, ServerSseMessage, accepted_response, expect_json, internal_error_response, sse_stream_response, }, @@ -196,7 +196,7 @@ where // create standalone stream let stream = self .session_manager - .create_stantalone_stream(&session_id) + .create_standalone_stream(&session_id) .await .map_err(internal_error_response("create standalone stream"))?; Ok(sse_stream_response(stream, self.config.sse_keep_alive)) From b128b892769d69c02e2e6fda3c6e8597d3d2ee83 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 27 May 2025 23:50:42 +0800 Subject: [PATCH 3/4] fix: accept non-request for streamable post handler --- .../streamable_http_server/session.rs | 5 +++ .../streamable_http_server/session/local.rs | 13 ++++++++ .../streamable_http_server/session/never.rs | 7 ++++ .../transport/streamable_http_server/tower.rs | 33 +++++++++++++++---- 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/crates/rmcp/src/transport/streamable_http_server/session.rs b/crates/rmcp/src/transport/streamable_http_server/session.rs index f2997585..f213126c 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session.rs @@ -33,6 +33,11 @@ pub trait SessionManager: Send + Sync + 'static { ) -> impl Future< Output = Result + Send + 'static, Self::Error>, > + Send; + fn accept_message( + &self, + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> impl Future> + Send; fn create_standalone_stream( &self, id: &SessionId, diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index 843a2a53..65adf76c 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -115,6 +115,19 @@ impl SessionManager for LocalSessionManager { let receiver = 
handle.resume(last_event_id.parse()?).await?; Ok(ReceiverStream::new(receiver.inner)) } + + async fn accept_message( + &self, + id: &SessionId, + message: ClientJsonRpcMessage, + ) -> Result<(), Self::Error> { + let sessions = self.sessions.read().await; + let handle = sessions + .get(id) + .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?; + handle.push_message(message, None).await?; + Ok(()) + } } /// `/request_id>` diff --git a/crates/rmcp/src/transport/streamable_http_server/session/never.rs b/crates/rmcp/src/transport/streamable_http_server/session/never.rs index 58abfaa5..436d4cfc 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/never.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/never.rs @@ -97,4 +97,11 @@ impl SessionManager for NeverSessionManager { ErrorSessionManagementNotSupported, )) } + fn accept_message( + &self, + _id: &SessionId, + _message: ClientJsonRpcMessage, + ) -> impl Future> + Send { + futures::future::ready(Err(ErrorSessionManagementNotSupported)) + } } diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 966bc48d..f9c0a69f 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -271,12 +271,33 @@ where ) .expect("valid response")); } - let stream = self - .session_manager - .create_stream(&session_id, message) - .await - .map_err(internal_error_response("get session"))?; - Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + match message { + ClientJsonRpcMessage::Request(_) => { + let stream = self + .session_manager + .create_stream(&session_id, message) + .await + .map_err(internal_error_response("get session"))?; + Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + } + ClientJsonRpcMessage::Notification(_) + | ClientJsonRpcMessage::Response(_) + | ClientJsonRpcMessage::Error(_) => { + // handle notification + self.session_manager + .accept_message(&session_id, message) + .await + .map_err(internal_error_response("accept message"))?; + Ok(accepted_response()) + } + _ => Ok(Response::builder() + .status(http::StatusCode::NOT_IMPLEMENTED) + .body( + Full::new(Bytes::from("Batch requests are not supported yet")) + .boxed_unsync(), + ) + .expect("valid response")), + } } else { let (session_id, transport) = self .session_manager From e8b0459d9b49f59eafa9f1031ce62e2d427537bc Mon Sep 17 00:00:00 2001 From: 4t145 Date: Wed, 28 May 2025 00:01:01 +0800 Subject: [PATCH 4/4] fix: remove copied document --- examples/servers/src/counter_hyper_streamable_http.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/servers/src/counter_hyper_streamable_http.rs b/examples/servers/src/counter_hyper_streamable_http.rs index fefb3d47..c9d2a3e6 100644 --- a/examples/servers/src/counter_hyper_streamable_http.rs +++ b/examples/servers/src/counter_hyper_streamable_http.rs @@ -16,7 +16,6 @@ async fn main() -> anyhow::Result<()> { LocalSessionManager::default().into(), Default::default(), )); - // GET /hello/warp => 200 OK with body "Hello, warp!" let listener = tokio::net::TcpListener::bind("[::1]:8080").await?; loop { let io = tokio::select! {