Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions crates/rmcp/src/transport/sse_server.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration};

use axum::{
Json, Router,
extract::{Query, State},
http::StatusCode,
response::{
Response,
sse::{Event, Sse},
sse::{Event, KeepAlive, Sse},
},
routing::{get, post},
};
use futures::{Sink, SinkExt, Stream, StreamExt};
use tokio::io;
use futures::{Sink, SinkExt, Stream};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::{CancellationToken, PollSender};
use tracing::Instrument;
Expand All @@ -26,28 +25,33 @@ type TxStore =
Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::mpsc::Sender<ClientJsonRpcMessage>>>>;
pub type TransportReceiver = ReceiverStream<RxJsonRpcMessage<RoleServer>>;

const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15);

#[derive(Clone)]
struct App {
txs: TxStore,
transport_tx: tokio::sync::mpsc::UnboundedSender<SseServerTransport>,
post_path: Arc<str>,
sse_ping_interval: Duration,
}

impl App {
pub fn new(
post_path: String,
sse_ping_interval: Duration,
) -> (
Self,
tokio::sync::mpsc::UnboundedReceiver<SseServerTransport>,
) {
let (transport_tx, tranport_rx) = tokio::sync::mpsc::unbounded_channel();
let (transport_tx, transport_rx) = tokio::sync::mpsc::unbounded_channel();
(
Self {
txs: Default::default(),
transport_tx,
post_path: post_path.into(),
sse_ping_interval,
},
tranport_rx,
transport_rx,
)
}
}
Expand Down Expand Up @@ -87,7 +91,7 @@ async fn sse_handler(
) -> Result<Sse<impl Stream<Item = Result<Event, io::Error>>>, Response<String>> {
let session = session_id();
tracing::info!(%session, "sse connection");
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::{StreamExt, wrappers::ReceiverStream};
use tokio_util::sync::PollSender;
let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64);
let (to_client_tx, to_client_rx) = tokio::sync::mpsc::channel(64);
Expand All @@ -108,11 +112,12 @@ async fn sse_handler(
if transport_send_result.is_err() {
tracing::warn!("send transport out error");
let mut response =
Response::new("fail to send out trasnport, it seems server is closed".to_string());
Response::new("fail to send out transport, it seems server is closed".to_string());
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Err(response);
}
let post_path = app.post_path.as_ref();
let ping_interval = app.sse_ping_interval;
let stream = futures::stream::once(futures::future::ok(
Event::default()
.event("endpoint")
Expand All @@ -124,7 +129,7 @@ async fn sse_handler(
Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)),
}
}));
Ok(Sse::new(stream))
Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(ping_interval)))
}

pub struct SseServerTransport {
Expand Down Expand Up @@ -190,6 +195,7 @@ impl Stream for SseServerTransport {
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
use futures::StreamExt;
self.stream.poll_next_unpin(cx)
}
}
Expand All @@ -200,6 +206,7 @@ pub struct SseServerConfig {
pub sse_path: String,
pub post_path: String,
pub ct: CancellationToken,
pub sse_keep_alive: Option<Duration>,
}

#[derive(Debug)]
Expand All @@ -215,6 +222,7 @@ impl SseServer {
sse_path: "/sse".to_string(),
post_path: "/message".to_string(),
ct: CancellationToken::new(),
sse_keep_alive: None,
})
.await
}
Expand All @@ -240,7 +248,10 @@ impl SseServer {
/// Warning: This function creates a new SseServer instance with the provided configuration.
/// `App.post_path` may be incorrect if using `Router` as an embedded router.
pub fn new(config: SseServerConfig) -> (SseServer, Router) {
let (app, transport_rx) = App::new(config.post_path.clone());
let (app, transport_rx) = App::new(
config.post_path.clone(),
config.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL),
);
let router = Router::new()
.route(&config.sse_path, get(sse_handler))
.route(&config.post_path, post(post_event_handler))
Expand Down
1 change: 1 addition & 0 deletions examples/servers/src/axum_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ async fn main() -> anyhow::Result<()> {
sse_path: "/sse".to_string(),
post_path: "/message".to_string(),
ct: tokio_util::sync::CancellationToken::new(),
sse_keep_alive: None,
};

let (sse_server, router) = SseServer::new(config);
Expand Down
Loading