diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index 252c2916..c6407ac7 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -119,6 +119,13 @@ impl Default for ProtocolVersion { Self::LATEST } } + +impl std::fmt::Display for ProtocolVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + impl ProtocolVersion { pub const V_2025_03_26: Self = Self(Cow::Borrowed("2025-03-26")); pub const V_2024_11_05: Self = Self(Cow::Borrowed("2024-11-05")); diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 84d62beb..56edead8 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -77,6 +77,7 @@ pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone { type PeerNot: TryInto + From + TransferObject; + type InitializeError; const IS_CLIENT: bool; type Info: TransferObject; type PeerInfo: TransferObject; @@ -113,7 +114,7 @@ pub trait ServiceExt: Service + Sized { fn serve( self, transport: T, - ) -> impl Future, E>> + Send + ) -> impl Future, R::InitializeError>> + Send where T: IntoTransport, E: std::error::Error + From + Send + Sync + 'static, @@ -125,7 +126,7 @@ pub trait ServiceExt: Service + Sized { self, transport: T, ct: CancellationToken, - ) -> impl Future, E>> + Send + ) -> impl Future, R::InitializeError>> + Send where T: IntoTransport, E: std::error::Error + From + Send + Sync + 'static, @@ -469,7 +470,7 @@ pub async fn serve_directly( service: S, transport: T, peer_info: R::PeerInfo, -) -> Result, E> +) -> RunningService where R: ServiceRole, S: Service, @@ -485,7 +486,7 @@ pub async fn serve_directly_with_ct( transport: T, peer_info: R::PeerInfo, ct: CancellationToken, -) -> Result, E> +) -> RunningService where R: ServiceRole, S: Service, @@ -503,7 +504,7 @@ async fn serve_inner( peer: Peer, mut peer_rx: tokio::sync::mpsc::Receiver>, ct: CancellationToken, -) -> Result, E> +) -> RunningService where R: ServiceRole, S: Service, @@ -788,10 +789,10 @@ where tracing::info!(?quit_reason, "serve finished"); quit_reason }); - Ok(RunningService { + RunningService { service, peer: peer_return, handle, dg: ct.drop_guard(), - }) + } } diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index df667728..9c43f83b 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use thiserror::Error; use super::*; @@ -19,7 +21,7 @@ use crate::model::{ /// /// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result, ClientError>` #[derive(Error, Debug)] -pub enum ClientError { +pub enum ClientInitializeError { #[error("expect initialized response, but received: {0:?}")] ExpectedInitResponse(Option), @@ -32,30 +34,32 @@ pub enum ClientError { #[error("connection closed: {0}")] ConnectionClosed(String), - #[error("IO error: {0}")] - Io(#[from] std::io::Error), + #[error("Send message error {error}, when {context}")] + TransportError { + error: E, + context: Cow<'static, str>, + }, } /// Helper function to get the next message from the stream -async fn expect_next_message( +async fn expect_next_message( transport: &mut T, context: &str, -) -> Result +) -> Result> where T: Transport, { transport .receive() .await - .ok_or_else(|| ClientError::ConnectionClosed(context.to_string())) - .map_err(|e| ClientError::Io(std::io::Error::new(std::io::ErrorKind::Other, e))) + .ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string())) } /// Helper function to expect a response from the stream -async fn expect_response( +async fn expect_response( transport: &mut T, context: &str, -) -> Result<(ServerResult, RequestId), ClientError> +) -> Result<(ServerResult, RequestId), ClientInitializeError> where T: Transport, { @@ -63,7 +67,7 @@ where match msg { ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)), - _ => Err(ClientError::ExpectedInitResponse(Some(msg))), + _ => Err(ClientInitializeError::ExpectedInitResponse(Some(msg))), } } @@ -79,7 +83,7 @@ impl ServiceRole for RoleClient { type PeerNot = ServerNotification; type Info = ClientInfo; type PeerInfo = ServerInfo; - + type InitializeError = ClientInitializeError; const IS_CLIENT: bool = true; } @@ -90,7 +94,7 @@ impl> ServiceExt for S { self, transport: T, ct: CancellationToken, - ) -> impl Future, E>> + Send + ) -> impl Future, ClientInitializeError>> + Send where T: IntoTransport, E: std::error::Error + From + Send + Sync + 'static, @@ -103,7 +107,7 @@ impl> ServiceExt for S { pub async fn serve_client( service: S, transport: T, -) -> Result, E> +) -> Result, ClientInitializeError> where S: Service, T: IntoTransport, @@ -116,7 +120,7 @@ pub async fn serve_client_with_ct( service: S, transport: T, ct: CancellationToken, -) -> Result, E> +) -> Result, ClientInitializeError> where S: Service, T: IntoTransport, @@ -125,14 +129,6 @@ where let mut transport = transport.into_transport(); let id_provider = >::default(); - // Convert ClientError to std::io::Error, then to E - let handle_client_error = |e: ClientError| -> E { - match e { - ClientError::Io(io_err) => io_err.into(), - other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(), - } - }; - // service let id = id_provider.next_request_id(); let init_request = InitializeRequest { @@ -145,23 +141,23 @@ where ClientRequest::InitializeRequest(init_request), id.clone(), )) - .await?; - - let (response, response_id) = expect_response(&mut transport, "initialize response") .await - .map_err(handle_client_error)?; + .map_err(|error| ClientInitializeError::TransportError { + error, + context: "send initialize request".into(), + })?; + + let (response, response_id) = expect_response(&mut transport, "initialize response").await?; if id != response_id { - return Err(handle_client_error(ClientError::ConflictInitResponseId( + return Err(ClientInitializeError::ConflictInitResponseId( id, response_id, - ))); + )); } let ServerResult::InitializeResult(initialize_result) = response else { - return Err(handle_client_error(ClientError::ExpectedInitResult(Some( - response, - )))); + return Err(ClientInitializeError::ExpectedInitResult(Some(response))); }; // send notification @@ -171,9 +167,15 @@ where extensions: Default::default(), }), ); - transport.send(notification).await?; + transport + .send(notification) + .await + .map_err(|error| ClientInitializeError::TransportError { + error, + context: "send initialized notification".into(), + })?; let (peer, peer_rx) = Peer::new(id_provider, initialize_result); - serve_inner(service, transport, peer, peer_rx, ct).await + Ok(serve_inner(service, transport, peer, peer_rx, ct).await) } macro_rules! method { diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index e0d8b91f..d40241d3 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use thiserror::Error; use super::*; @@ -6,9 +8,9 @@ use crate::model::{ ClientNotification, ClientRequest, ClientResult, CreateMessageRequest, CreateMessageRequestParam, CreateMessageResult, ErrorData, ListRootsRequest, ListRootsResult, LoggingMessageNotification, LoggingMessageNotificationParam, ProgressNotification, - ProgressNotificationParam, PromptListChangedNotification, ResourceListChangedNotification, - ResourceUpdatedNotification, ResourceUpdatedNotificationParam, ServerInfo, ServerNotification, - ServerRequest, ServerResult, ToolListChangedNotification, + ProgressNotificationParam, PromptListChangedNotification, ProtocolVersion, + ResourceListChangedNotification, ResourceUpdatedNotification, ResourceUpdatedNotificationParam, + ServerInfo, ServerNotification, ServerRequest, ServerResult, ToolListChangedNotification, }; #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] @@ -23,6 +25,8 @@ impl ServiceRole for RoleServer { type PeerNot = ClientNotification; type Info = ServerInfo; type PeerInfo = ClientInfo; + + type InitializeError = ServerInitializeError; const IS_CLIENT: bool = false; } @@ -30,12 +34,12 @@ impl ServiceRole for RoleServer { /// /// if you want to handle the error, you can use `serve_server_with_ct` or `serve_server` with `Result, ServerError>` #[derive(Error, Debug)] -pub enum ServerError { +pub enum ServerInitializeError { #[error("expect initialized request, but received: {0:?}")] - ExpectedInitRequest(Option), + ExpectedInitializeRequest(Option), #[error("expect initialized notification, but received: {0:?}")] - ExpectedInitNotification(Option), + ExpectedInitializedNotification(Option), #[error("connection closed: {0}")] ConnectionClosed(String), @@ -46,8 +50,14 @@ pub enum ServerError { #[error("initialize failed: {0}")] InitializeFailed(ErrorData), - #[error("IO error: {0}")] - Io(#[from] std::io::Error), + #[error("unsupported protocol version: {0}")] + UnsupportedProtocolVersion(ProtocolVersion), + + #[error("Send message error {error}, when {context}")] + TransportError { + error: E, + context: Cow<'static, str>, + }, } pub type ClientSink = Peer; @@ -57,7 +67,7 @@ impl> ServiceExt for S { self, transport: T, ct: CancellationToken, - ) -> impl Future, E>> + Send + ) -> impl Future, ServerInitializeError>> + Send where T: IntoTransport, E: std::error::Error + From + Send + Sync + 'static, @@ -70,7 +80,7 @@ impl> ServiceExt for S { pub async fn serve_server( service: S, transport: T, -) -> Result, E> +) -> Result, ServerInitializeError> where S: Service, T: IntoTransport, @@ -80,52 +90,56 @@ where } /// Helper function to get the next message from the stream -async fn expect_next_message( +async fn expect_next_message( transport: &mut T, context: &str, -) -> Result +) -> Result> where T: Transport, { transport .receive() .await - .ok_or_else(|| ServerError::ConnectionClosed(context.to_string())) + .ok_or_else(|| ServerInitializeError::ConnectionClosed(context.to_string())) } /// Helper function to expect a request from the stream -async fn expect_request( +async fn expect_request( transport: &mut T, context: &str, -) -> Result<(ClientRequest, RequestId), ServerError> +) -> Result<(ClientRequest, RequestId), ServerInitializeError> where T: Transport, { let msg = expect_next_message(transport, context).await?; let msg_clone = msg.clone(); msg.into_request() - .ok_or(ServerError::ExpectedInitRequest(Some(msg_clone))) + .ok_or(ServerInitializeError::ExpectedInitializeRequest(Some( + msg_clone, + ))) } /// Helper function to expect a notification from the stream -async fn expect_notification( +async fn expect_notification( transport: &mut T, context: &str, -) -> Result +) -> Result> where T: Transport, { let msg = expect_next_message(transport, context).await?; let msg_clone = msg.clone(); msg.into_notification() - .ok_or(ServerError::ExpectedInitNotification(Some(msg_clone))) + .ok_or(ServerInitializeError::ExpectedInitializedNotification( + Some(msg_clone), + )) } pub async fn serve_server_with_ct( service: S, transport: T, ct: CancellationToken, -) -> Result, E> +) -> Result, ServerInitializeError> where S: Service, T: IntoTransport, @@ -134,23 +148,13 @@ where let mut transport = transport.into_transport(); let id_provider = >::default(); - // Convert ServerError to std::io::Error, then to E - let handle_server_error = |e: ServerError| -> E { - match e { - ServerError::Io(io_err) => io_err.into(), - other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(), - } - }; - // Get initialize request - let (request, id) = expect_request(&mut transport, "initialized request") - .await - .map_err(handle_server_error)?; + let (request, id) = expect_request(&mut transport, "initialized request").await?; let ClientRequest::InitializeRequest(peer_info) = &request else { - return Err(handle_server_error(ServerError::ExpectedInitRequest(Some( + return Err(ServerInitializeError::ExpectedInitializeRequest(Some( ClientJsonRpcMessage::request(request, id), - )))); + ))); }; let (peer, peer_rx) = Peer::new(id_provider, peer_info.params.clone()); let context = RequestContext { @@ -165,24 +169,24 @@ where let mut init_response = match init_response { Ok(ServerResult::InitializeResult(init_response)) => init_response, Ok(result) => { - return Err(handle_server_error( - ServerError::UnexpectedInitializeResponse(result), - )); + return Err(ServerInitializeError::UnexpectedInitializeResponse(result)); } Err(e) => { transport .send(ServerJsonRpcMessage::error(e.clone(), id)) - .await?; - return Err(handle_server_error(ServerError::InitializeFailed(e))); + .await + .map_err(|error| ServerInitializeError::TransportError { + error, + context: "sending error response".into(), + })?; + return Err(ServerInitializeError::InitializeFailed(e)); } }; - let protocol_version = match peer_info - .params - .protocol_version + let peer_protocol_version = peer_info.params.protocol_version.clone(); + let protocol_version = match peer_protocol_version .partial_cmp(&init_response.protocol_version) - .ok_or(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "unsupported protocol version", + .ok_or(ServerInitializeError::UnsupportedProtocolVersion( + peer_protocol_version, ))? { std::cmp::Ordering::Less => peer_info.params.protocol_version.clone(), _ => init_response.protocol_version, @@ -193,20 +197,22 @@ where ServerResult::InitializeResult(init_response), id, )) - .await?; + .await + .map_err(|error| ServerInitializeError::TransportError { + error, + context: "sending initialize response".into(), + })?; // Wait for initialize notification - let notification = expect_notification(&mut transport, "initialize notification") - .await - .map_err(handle_server_error)?; + let notification = expect_notification(&mut transport, "initialize notification").await?; let ClientNotification::InitializedNotification(_) = notification else { - return Err(handle_server_error(ServerError::ExpectedInitNotification( + return Err(ServerInitializeError::ExpectedInitializedNotification( Some(ClientJsonRpcMessage::notification(notification)), - ))); + )); }; let _ = service.handle_notification(notification).await; // Continue processing service - serve_inner(service, transport, peer, peer_rx, ct).await + Ok(serve_inner(service, transport, peer, peer_rx, ct).await) } macro_rules! method { diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index 7dce63c1..2bc8ec86 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -288,7 +288,10 @@ impl SseServer { let service = service_provider(); let ct = self.config.ct.child_token(); tokio::spawn(async move { - let server = service.serve_with_ct(transport, ct).await?; + let server = service + .serve_with_ct(transport, ct) + .await + .map_err(std::io::Error::other)?; server.waiting().await?; tokio::io::Result::Ok(()) }); diff --git a/crates/rmcp/src/transport/streamable_http_server/axum.rs b/crates/rmcp/src/transport/streamable_http_server/axum.rs index 9bc70b49..016dee6b 100644 --- a/crates/rmcp/src/transport/streamable_http_server/axum.rs +++ b/crates/rmcp/src/transport/streamable_http_server/axum.rs @@ -327,7 +327,10 @@ impl StreamableHttpServer { let service = service_provider(); let ct = self.config.ct.child_token(); tokio::spawn(async move { - let server = service.serve_with_ct(transport, ct).await?; + let server = service + .serve_with_ct(transport, ct) + .await + .map_err(tokio::io::Error::other)?; server.waiting().await?; tokio::io::Result::Ok(()) });