From b061ae374f7ee6170e859850045f8ecdd6463b13 Mon Sep 17 00:00:00 2001 From: = <=> Date: Tue, 1 Apr 2025 01:44:29 +0800 Subject: [PATCH 1/2] feat(sse-server): auto ping in sse stream every second to make cursor happy --- crates/rmcp/src/transport/sse_server.rs | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index ef0b0309..d661f1d3 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -13,8 +13,8 @@ use axum::{ }, routing::{get, post}, }; -use futures::{Sink, SinkExt, Stream, StreamExt}; -use std::{collections::HashMap, net::SocketAddr}; +use futures::{Sink, SinkExt, Stream}; +use std::{collections::HashMap, net::SocketAddr, time::Duration}; use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::{CancellationToken, PollSender}; use tracing::Instrument; @@ -40,14 +40,14 @@ impl App { Self, tokio::sync::mpsc::UnboundedReceiver, ) { - let (transport_tx, tranport_rx) = tokio::sync::mpsc::unbounded_channel(); + let (transport_tx, transport_rx) = tokio::sync::mpsc::unbounded_channel(); ( Self { txs: Default::default(), transport_tx, post_path: post_path.into(), }, - tranport_rx, + transport_rx, ) } } @@ -85,8 +85,10 @@ async fn post_event_handler( async fn sse_handler( State(app): State, ) -> Result>>, Response> { + const AUTO_PING_INTERVAL: Duration = Duration::from_secs(1); let session = session_id(); tracing::info!(%session, "sse connection"); + use tokio_stream::StreamExt; use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::PollSender; let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64); @@ -108,7 +110,7 @@ async fn sse_handler( if transport_send_result.is_err() { tracing::warn!("send transport out error"); let mut response = - Response::new("fail to send out trasnport, it seems server is closed".to_string()); + Response::new("fail to send out transport, it seems server is closed".to_string()); *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; return Err(response); } @@ -118,12 +120,16 @@ async fn sse_handler( .event("endpoint") .data(format!("{post_path}?sessionId={session}")), )) - .chain(ReceiverStream::new(to_client_rx).map(|message| { - match serde_json::to_string(&message) { + .chain( + ReceiverStream::new(to_client_rx).map(|message| match serde_json::to_string(&message) { Ok(bytes) => Ok(Event::default().event("message").data(&bytes)), Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), - } - })); + }), + ) + .merge( + tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(AUTO_PING_INTERVAL)) + .map(|_| Ok(Event::default().comment("ping"))), + ); Ok(Sse::new(stream)) } @@ -190,6 +196,7 @@ impl Stream for SseServerTransport { mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { + use futures::StreamExt; self.stream.poll_next_unpin(cx) } } From 3a39f4153935c83a982bb2cfa32e3def20f59124 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 1 Apr 2025 10:15:11 +0800 Subject: [PATCH 2/2] feat(sse-server): configurable sse keep alive --- crates/rmcp/src/transport/sse_server.rs | 30 +++++++++++++++---------- examples/servers/src/axum_router.rs | 1 + 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index 24e82a50..ed04b3d0 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -6,7 +6,7 @@ use axum::{ http::StatusCode, response::{ Response, - sse::{Event, Sse}, + sse::{Event, KeepAlive, Sse}, }, routing::{get, post}, }; @@ -25,16 +25,20 @@ type TxStore = Arc>>>; pub type TransportReceiver = ReceiverStream>; +const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15); + #[derive(Clone)] struct App { txs: TxStore, transport_tx: tokio::sync::mpsc::UnboundedSender, post_path: Arc, + sse_ping_interval: Duration, } impl App { pub fn new( post_path: String, + sse_ping_interval: Duration, ) -> ( Self, tokio::sync::mpsc::UnboundedReceiver, @@ -45,6 +49,7 @@ impl App { txs: Default::default(), transport_tx, post_path: post_path.into(), + sse_ping_interval, }, transport_rx, ) @@ -84,7 +89,6 @@ async fn post_event_handler( async fn sse_handler( State(app): State, ) -> Result>>, Response> { - const AUTO_PING_INTERVAL: Duration = Duration::from_secs(1); let session = session_id(); tracing::info!(%session, "sse connection"); use tokio_stream::{StreamExt, wrappers::ReceiverStream}; @@ -113,22 +117,19 @@ async fn sse_handler( return Err(response); } let post_path = app.post_path.as_ref(); + let ping_interval = app.sse_ping_interval; let stream = futures::stream::once(futures::future::ok( Event::default() .event("endpoint") .data(format!("{post_path}?sessionId={session}")), )) - .chain( - ReceiverStream::new(to_client_rx).map(|message| match serde_json::to_string(&message) { + .chain(ReceiverStream::new(to_client_rx).map(|message| { + match serde_json::to_string(&message) { Ok(bytes) => Ok(Event::default().event("message").data(&bytes)), Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), - }), - ) - .merge( - tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(AUTO_PING_INTERVAL)) - .map(|_| Ok(Event::default().comment("ping"))), - ); - Ok(Sse::new(stream)) + } + })); + Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(ping_interval))) } pub struct SseServerTransport { @@ -205,6 +206,7 @@ pub struct SseServerConfig { pub sse_path: String, pub post_path: String, pub ct: CancellationToken, + pub sse_keep_alive: Option, } #[derive(Debug)] @@ -220,6 +222,7 @@ impl SseServer { sse_path: "/sse".to_string(), post_path: "/message".to_string(), ct: CancellationToken::new(), + sse_keep_alive: None, }) .await } @@ -245,7 +248,10 @@ impl SseServer { /// Warning: This function creates a new SseServer instance with the provided configuration. /// `App.post_path` may be incorrect if using `Router` as an embedded router. pub fn new(config: SseServerConfig) -> (SseServer, Router) { - let (app, transport_rx) = App::new(config.post_path.clone()); + let (app, transport_rx) = App::new( + config.post_path.clone(), + config.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL), + ); let router = Router::new() .route(&config.sse_path, get(sse_handler)) .route(&config.post_path, post(post_event_handler)) diff --git a/examples/servers/src/axum_router.rs b/examples/servers/src/axum_router.rs index fa456935..373a8a6d 100644 --- a/examples/servers/src/axum_router.rs +++ b/examples/servers/src/axum_router.rs @@ -24,6 +24,7 @@ async fn main() -> anyhow::Result<()> { sse_path: "/sse".to_string(), post_path: "/message".to_string(), ct: tokio_util::sync::CancellationToken::new(), + sse_keep_alive: None, }; let (sse_server, router) = SseServer::new(config);