Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ impl Default for ProtocolVersion {
Self::LATEST
}
}

impl std::fmt::Display for ProtocolVersion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

impl ProtocolVersion {
pub const V_2025_03_26: Self = Self(Cow::Borrowed("2025-03-26"));
pub const V_2024_11_05: Self = Self(Cow::Borrowed("2024-11-05"));
Expand Down
15 changes: 8 additions & 7 deletions crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone {
type PeerNot: TryInto<CancelledNotification, Error = Self::PeerNot>
+ From<CancelledNotification>
+ TransferObject;
type InitializeError<E>;
const IS_CLIENT: bool;
type Info: TransferObject;
type PeerInfo: TransferObject;
Expand Down Expand Up @@ -113,7 +114,7 @@ pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized {
fn serve<T, E, A>(
self,
transport: T,
) -> impl Future<Output = Result<RunningService<R, Self>, E>> + Send
) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError<E>>> + Send
where
T: IntoTransport<R, E, A>,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
Expand All @@ -125,7 +126,7 @@ pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized {
self,
transport: T,
ct: CancellationToken,
) -> impl Future<Output = Result<RunningService<R, Self>, E>> + Send
) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError<E>>> + Send
where
T: IntoTransport<R, E, A>,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
Expand Down Expand Up @@ -469,7 +470,7 @@ pub async fn serve_directly<R, S, T, E, A>(
service: S,
transport: T,
peer_info: R::PeerInfo,
) -> Result<RunningService<R, S>, E>
) -> RunningService<R, S>
where
R: ServiceRole,
S: Service<R>,
Expand All @@ -485,7 +486,7 @@ pub async fn serve_directly_with_ct<R, S, T, E, A>(
transport: T,
peer_info: R::PeerInfo,
ct: CancellationToken,
) -> Result<RunningService<R, S>, E>
) -> RunningService<R, S>
where
R: ServiceRole,
S: Service<R>,
Expand All @@ -503,7 +504,7 @@ async fn serve_inner<R, S, T, E, A>(
peer: Peer<R>,
mut peer_rx: tokio::sync::mpsc::Receiver<PeerSinkMessage<R>>,
ct: CancellationToken,
) -> Result<RunningService<R, S>, E>
) -> RunningService<R, S>
where
R: ServiceRole,
S: Service<R>,
Expand Down Expand Up @@ -788,10 +789,10 @@ where
tracing::info!(?quit_reason, "serve finished");
quit_reason
});
Ok(RunningService {
RunningService {
service,
peer: peer_return,
handle,
dg: ct.drop_guard(),
})
}
}
68 changes: 35 additions & 33 deletions crates/rmcp/src/service/client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::borrow::Cow;

use thiserror::Error;

use super::*;
Expand All @@ -19,7 +21,7 @@ use crate::model::{
///
/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
#[derive(Error, Debug)]
pub enum ClientError {
pub enum ClientInitializeError<E> {
#[error("expect initialized response, but received: {0:?}")]
ExpectedInitResponse(Option<ServerJsonRpcMessage>),

Expand All @@ -32,38 +34,40 @@ pub enum ClientError {
#[error("connection closed: {0}")]
ConnectionClosed(String),

#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Send message error {error}, when {context}")]
TransportError {
error: E,
context: Cow<'static, str>,
},
}

/// Helper function to get the next message from the stream
async fn expect_next_message<T>(
async fn expect_next_message<T, E>(
transport: &mut T,
context: &str,
) -> Result<ServerJsonRpcMessage, ClientError>
) -> Result<ServerJsonRpcMessage, ClientInitializeError<E>>
where
T: Transport<RoleClient>,
{
transport
.receive()
.await
.ok_or_else(|| ClientError::ConnectionClosed(context.to_string()))
.map_err(|e| ClientError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))
.ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string()))
}

/// Helper function to expect a response from the stream
async fn expect_response<T>(
async fn expect_response<T, E>(
transport: &mut T,
context: &str,
) -> Result<(ServerResult, RequestId), ClientError>
) -> Result<(ServerResult, RequestId), ClientInitializeError<E>>
where
T: Transport<RoleClient>,
{
let msg = expect_next_message(transport, context).await?;

match msg {
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)),
_ => Err(ClientError::ExpectedInitResponse(Some(msg))),
_ => Err(ClientInitializeError::ExpectedInitResponse(Some(msg))),
}
}

Expand All @@ -79,7 +83,7 @@ impl ServiceRole for RoleClient {
type PeerNot = ServerNotification;
type Info = ClientInfo;
type PeerInfo = ServerInfo;

type InitializeError<E> = ClientInitializeError<E>;
const IS_CLIENT: bool = true;
}

Expand All @@ -90,7 +94,7 @@ impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
self,
transport: T,
ct: CancellationToken,
) -> impl Future<Output = Result<RunningService<RoleClient, Self>, E>> + Send
) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError<E>>> + Send
where
T: IntoTransport<RoleClient, E, A>,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
Expand All @@ -103,7 +107,7 @@ impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
pub async fn serve_client<S, T, E, A>(
service: S,
transport: T,
) -> Result<RunningService<RoleClient, S>, E>
) -> Result<RunningService<RoleClient, S>, ClientInitializeError<E>>
where
S: Service<RoleClient>,
T: IntoTransport<RoleClient, E, A>,
Expand All @@ -116,7 +120,7 @@ pub async fn serve_client_with_ct<S, T, E, A>(
service: S,
transport: T,
ct: CancellationToken,
) -> Result<RunningService<RoleClient, S>, E>
) -> Result<RunningService<RoleClient, S>, ClientInitializeError<E>>
where
S: Service<RoleClient>,
T: IntoTransport<RoleClient, E, A>,
Expand All @@ -125,14 +129,6 @@ where
let mut transport = transport.into_transport();
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();

// Convert ClientError to std::io::Error, then to E
let handle_client_error = |e: ClientError| -> E {
match e {
ClientError::Io(io_err) => io_err.into(),
other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(),
}
};

// service
let id = id_provider.next_request_id();
let init_request = InitializeRequest {
Expand All @@ -145,23 +141,23 @@ where
ClientRequest::InitializeRequest(init_request),
id.clone(),
))
.await?;

let (response, response_id) = expect_response(&mut transport, "initialize response")
.await
.map_err(handle_client_error)?;
.map_err(|error| ClientInitializeError::TransportError {
error,
context: "send initialize request".into(),
})?;

let (response, response_id) = expect_response(&mut transport, "initialize response").await?;

if id != response_id {
return Err(handle_client_error(ClientError::ConflictInitResponseId(
return Err(ClientInitializeError::ConflictInitResponseId(
id,
response_id,
)));
));
}

let ServerResult::InitializeResult(initialize_result) = response else {
return Err(handle_client_error(ClientError::ExpectedInitResult(Some(
response,
))));
return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
};

// send notification
Expand All @@ -171,9 +167,15 @@ where
extensions: Default::default(),
}),
);
transport.send(notification).await?;
transport
.send(notification)
.await
.map_err(|error| ClientInitializeError::TransportError {
error,
context: "send initialized notification".into(),
})?;
let (peer, peer_rx) = Peer::new(id_provider, initialize_result);
serve_inner(service, transport, peer, peer_rx, ct).await
Ok(serve_inner(service, transport, peer, peer_rx, ct).await)
}

macro_rules! method {
Expand Down
Loading
Loading