Skip to content

Commit 0c2a48a

Browse files
committed
rewrite features
1 parent b727966 commit 0c2a48a

File tree

11 files changed

+212
-229
lines changed

11 files changed

+212
-229
lines changed

crates/rmcp/Cargo.toml

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,15 @@ reqwest = ["__reqwest", "reqwest?/rustls-tls"]
7676

7777
reqwest-tls-no-provider = ["__reqwest", "reqwest?/rustls-tls-no-provider"]
7878

79-
axum = ["dep:axum"]
79+
server-side-http = [
80+
"uuid",
81+
"dep:rand",
82+
"dep:tokio-stream",
83+
"dep:http-body",
84+
"dep:http-body-util",
85+
"dep:bytes",
86+
"tower",
87+
]
8088
# SSE client
8189
client-side-sse = ["dep:sse-stream", "dep:http"]
8290

@@ -99,19 +107,11 @@ transport-child-process = [
99107
transport-sse-server = [
100108
"transport-async-rw",
101109
"transport-worker",
102-
"axum",
103-
"dep:rand",
104-
"dep:tokio-stream",
105-
"uuid",
110+
"server-side-http",
106111
]
107112
transport-streamable-http-server = [
108113
"transport-streamable-http-server-session",
109-
"axum",
110-
"uuid",
111-
"tower",
112-
"dep:http-body",
113-
"dep:http-body-util",
114-
"dep:bytes",
114+
"server-side-http",
115115
]
116116
transport-streamable-http-server-session = [
117117
"transport-async-rw",

crates/rmcp/src/transport.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, Authorized
122122
pub mod streamable_http_server;
123123
#[cfg(feature = "transport-streamable-http-server")]
124124
#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-server")))]
125-
pub use streamable_http_server::tower::{StreamableHttpService, StreamableHttpServerConfig};
125+
pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHttpService};
126126

127127
#[cfg(feature = "transport-streamable-http-client")]
128128
#[cfg_attr(docsrs, doc(cfg(feature = "transport-streamable-http-client")))]
@@ -169,7 +169,7 @@ impl<R, T, E> IntoTransport<R, E, TransportAdapterIdentity> for T
169169
where
170170
T: Transport<R, Error = E> + Send + 'static,
171171
R: ServiceRole,
172-
E: std::error::Error + Send + 'static,
172+
E: std::error::Error + Send + Sync + 'static,
173173
{
174174
fn into_transport(self) -> impl Transport<R, Error = E> + 'static {
175175
self

crates/rmcp/src/transport/common.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
feature = "transport-streamable-http-server",
33
feature = "transport-sse-server"
44
))]
5-
pub mod axum;
5+
pub mod sever_side_http;
66

77
pub mod http_header;
88

crates/rmcp/src/transport/common/axum.rs

Lines changed: 0 additions & 9 deletions
This file was deleted.
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
2+
3+
use bytes::{Buf, Bytes};
4+
use http::Response;
5+
use http_body::Body;
6+
use http_body_util::{BodyExt, Empty, Full, combinators::UnsyncBoxBody};
7+
use sse_stream::{KeepAlive, Sse, SseBody};
8+
9+
use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
10+
11+
use super::http_header::EVENT_STREAM_MIME_TYPE;
12+
13+
pub type SessionId = Arc<str>;
14+
15+
pub fn session_id() -> SessionId {
16+
uuid::Uuid::new_v4().to_string().into()
17+
}
18+
19+
pub const DEFAULT_AUTO_PING_INTERVAL: Duration = Duration::from_secs(15);
20+
21+
pub(crate) type BoxResponse = Response<UnsyncBoxBody<Bytes, Infallible>>;
22+
23+
pub(crate) fn accecpted_response() -> Response<UnsyncBoxBody<Bytes, Infallible>> {
24+
Response::builder()
25+
.status(http::StatusCode::ACCEPTED)
26+
.body(Empty::new().boxed_unsync())
27+
.expect("valid response")
28+
}
29+
pin_project_lite::pin_project! {
30+
struct TokioTimer {
31+
#[pin]
32+
sleep: tokio::time::Sleep,
33+
}
34+
}
35+
impl Future for TokioTimer {
36+
type Output = ();
37+
38+
fn poll(
39+
self: std::pin::Pin<&mut Self>,
40+
cx: &mut std::task::Context<'_>,
41+
) -> std::task::Poll<Self::Output> {
42+
let this = self.project();
43+
this.sleep.poll(cx)
44+
}
45+
}
46+
impl sse_stream::Timer for TokioTimer {
47+
fn from_duration(duration: Duration) -> Self {
48+
Self {
49+
sleep: tokio::time::sleep(duration),
50+
}
51+
}
52+
53+
fn reset(self: std::pin::Pin<&mut Self>, when: std::time::Instant) {
54+
let this = self.project();
55+
this.sleep.reset(tokio::time::Instant::from_std(when));
56+
}
57+
}
58+
59+
#[derive(Debug, Clone)]
60+
pub struct ServerSseMessage {
61+
pub event_id: Option<String>,
62+
pub message: Arc<ServerJsonRpcMessage>,
63+
}
64+
65+
pub(crate) fn sse_stream_response(
66+
stream: impl futures::Stream<Item = ServerSseMessage> + Send + 'static,
67+
keep_alive: Option<Duration>,
68+
) -> Response<UnsyncBoxBody<Bytes, Infallible>> {
69+
use futures::StreamExt;
70+
let stream = SseBody::new(stream.map(|message| {
71+
let data = serde_json::to_string(&message.message).expect("valid message");
72+
let mut sse = Sse::default().data(data);
73+
sse.id = message.event_id;
74+
Result::<Sse, Infallible>::Ok(sse)
75+
}));
76+
let stream = match keep_alive {
77+
Some(duration) => stream
78+
.with_keep_alive::<TokioTimer>(KeepAlive::new().interval(duration))
79+
.boxed_unsync(),
80+
None => stream.boxed_unsync(),
81+
};
82+
Response::builder()
83+
.status(http::StatusCode::OK)
84+
.header(http::header::CONTENT_TYPE, EVENT_STREAM_MIME_TYPE)
85+
.header(http::header::CACHE_CONTROL, "no-cache")
86+
.body(stream)
87+
.expect("valid response")
88+
}
89+
90+
pub(crate) const fn internal_error_response<E: Display>(
91+
context: &str,
92+
) -> impl FnOnce(E) -> Response<UnsyncBoxBody<Bytes, Infallible>> {
93+
move |error| {
94+
tracing::error!("Internal server error when {context}: {error}");
95+
Response::builder()
96+
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
97+
.body(
98+
Full::new(Bytes::from(format!(
99+
"Encounter an error when {context}: {error}"
100+
)))
101+
.boxed_unsync(),
102+
)
103+
.expect("valid response")
104+
}
105+
}
106+
107+
pub(crate) async fn expect_json<B>(
108+
body: B,
109+
) -> Result<ClientJsonRpcMessage, Response<UnsyncBoxBody<Bytes, Infallible>>>
110+
where
111+
B: Body + Send + 'static,
112+
B::Error: Display,
113+
{
114+
match body.collect().await {
115+
Ok(bytes) => {
116+
match serde_json::from_reader::<_, ClientJsonRpcMessage>(bytes.aggregate().reader()) {
117+
Ok(message) => Ok(message),
118+
Err(e) => {
119+
let response = Response::builder()
120+
.status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
121+
.body(
122+
Full::new(Bytes::from(format!("fail to deserialize request body {e}")))
123+
.boxed_unsync(),
124+
)
125+
.expect("valid response");
126+
Err(response)
127+
}
128+
}
129+
}
130+
Err(e) => {
131+
let response = Response::builder()
132+
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
133+
.body(
134+
Full::new(Bytes::from(format!("Failed to read request body: {e}")))
135+
.boxed_unsync(),
136+
)
137+
.expect("valid response");
138+
Err(response)
139+
}
140+
}
141+
}

crates/rmcp/src/transport/sse_server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::{
1919
RoleServer, Service,
2020
model::ClientJsonRpcMessage,
2121
service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct},
22-
transport::common::axum::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id},
22+
transport::common::sever_side_http::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id},
2323
};
2424

2525
type TxStore =

crates/rmcp/src/transport/streamable_http_server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
pub mod axum;
44
pub mod session;
55
pub mod tower;
6-
pub use session::{ServerSseMessage, SessionManager, SessionId};
6+
pub use session::{SessionId, SessionManager};

crates/rmcp/src/transport/streamable_http_server/session.rs

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,16 @@
1-
pub use crate::transport::common::axum::SessionId;
1+
pub use crate::transport::common::sever_side_http::SessionId;
22
use futures::Stream;
3-
use std::{
4-
borrow::Cow,
5-
collections::{HashMap, HashSet, VecDeque},
6-
num::ParseIntError,
7-
sync::Arc,
8-
};
9-
use thiserror::Error;
10-
use tokio::sync::{
11-
mpsc::{Receiver, Sender},
12-
oneshot,
13-
};
14-
use tokio_stream::wrappers::ReceiverStream;
15-
use tracing::instrument;
163

174
use crate::{
18-
RoleServer,
195
model::{
20-
CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest,
21-
JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam,
22-
ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification,
23-
},
24-
transport::{
25-
WorkerTransport,
26-
common::axum::session_id,
27-
worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest},
28-
},
6+
ClientJsonRpcMessage, ServerJsonRpcMessage,
7+
}, transport::common::sever_side_http::ServerSseMessage, RoleServer
298
};
309

3110
pub mod local;
3211
pub mod never;
3312

34-
#[derive(Debug, Clone)]
35-
pub struct ServerSseMessage {
36-
pub event_id: String,
37-
pub message: Arc<ServerJsonRpcMessage>,
38-
}
13+
3914

4015
pub trait SessionManager: Send + Sync + 'static {
4116
type Error: std::error::Error + Send + 'static;

crates/rmcp/src/transport/streamable_http_server/session/local.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
use futures::{Stream, StreamExt};
1+
use futures::Stream;
22
use std::{
3-
borrow::Cow,
43
collections::{HashMap, HashSet, VecDeque},
54
num::ParseIntError,
65
sync::Arc,
@@ -22,7 +21,7 @@ use crate::{
2221
},
2322
transport::{
2423
WorkerTransport,
25-
common::axum::{SessionId, session_id},
24+
common::sever_side_http::{SessionId, session_id},
2625
worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest},
2726
},
2827
};
@@ -48,7 +47,7 @@ impl SessionManager for LocalSessionManager {
4847
let id = session_id();
4948
let (handle, worker) = create_local_session(id.clone(), Default::default());
5049
self.sessions.write().await.insert(id.clone(), handle);
51-
return Ok((id, WorkerTransport::spawn(worker)));
50+
Ok((id, WorkerTransport::spawn(worker)))
5251
}
5352
async fn initialize_session(
5453
&self,
@@ -188,14 +187,14 @@ impl CachedTx {
188187

189188
async fn send(&mut self, message: ServerJsonRpcMessage) {
190189
let index = self.cache.back().map_or(0, |m| {
191-
m.event_id.parse::<EventId>().expect("valid event id").index + 1
190+
m.event_id.as_deref().unwrap_or_default().parse::<EventId>().expect("valid event id").index + 1
192191
});
193192
let event_id = EventId {
194193
http_request_id: self.http_request_id,
195194
index,
196195
};
197196
let message = ServerSseMessage {
198-
event_id: event_id.to_string(),
197+
event_id: Some(event_id.to_string()),
199198
message: Arc::new(message),
200199
};
201200
if self.cache.len() >= self.capacity {
@@ -206,15 +205,15 @@ impl CachedTx {
206205
}
207206
let _ = self.tx.send(message).await.inspect_err(|e| {
208207
let event_id = &e.0.event_id;
209-
tracing::trace!(%event_id, "trying to send message in a closed session")
208+
tracing::trace!(?event_id, "trying to send message in a closed session")
210209
});
211210
}
212211

213212
async fn sync(&mut self, index: usize) -> Result<(), SessionError> {
214213
let Some(front) = self.cache.front() else {
215214
return Ok(());
216215
};
217-
let front_event_id = front.event_id.parse::<EventId>()?;
216+
let front_event_id = front.event_id.as_deref().unwrap_or_default().parse::<EventId>()?;
218217
let sync_index = index.saturating_sub(front_event_id.index);
219218
if sync_index > self.cache.len() {
220219
// invalid index
@@ -223,7 +222,7 @@ impl CachedTx {
223222
for message in self.cache.iter().skip(sync_index) {
224223
let send_result = self.tx.send(message.clone()).await;
225224
if send_result.is_err() {
226-
let event_id: EventId = message.event_id.parse()?;
225+
let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?;
227226
return Err(SessionError::ChannelClosed(Some(event_id.index as u64)));
228227
}
229228
}
@@ -726,7 +725,6 @@ impl Worker for LocalSessionWorker {
726725
};
727726
match event {
728727
InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => {
729-
tracing::info!(?message, "received message from handler");
730728
// catch response
731729
let to_unregister = match &message {
732730
crate::model::JsonRpcMessage::Response(json_rpc_response) => {

0 commit comments

Comments
 (0)