diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 134b2ffb..b2e30e0c 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -1,4 +1,5 @@ -use futures::{SinkExt, StreamExt}; +use futures::{SinkExt, Stream, StreamExt}; +use thiserror::Error; use super::*; use crate::model::{ @@ -6,15 +7,67 @@ use crate::model::{ CancelledNotificationParam, ClientInfo, ClientMessage, ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, GetPromptRequest, GetPromptRequestParam, GetPromptResult, InitializeRequest, InitializedNotification, - ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest, + JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest, ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest, ListToolsResult, PaginatedRequestParam, PaginatedRequestParamInner, ProgressNotification, ProgressNotificationParam, ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult, - RootsListChangedNotification, ServerInfo, ServerNotification, ServerRequest, ServerResult, - SetLevelRequest, SetLevelRequestParam, SubscribeRequest, SubscribeRequestParam, - UnsubscribeRequest, UnsubscribeRequestParam, + RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification, + ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam, SubscribeRequest, + SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam, }; +/// It represents the error that may occur when serving the client. +/// +/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result, ClientError>` +#[derive(Error, Debug)] +pub enum ClientError { + #[error("expect initialized response, but received: {0:?}")] + ExpectedInitResponse(Option), + + #[error("expect initialized result, but received: {0:?}")] + ExpectedInitResult(Option), + + #[error("conflict initialized response id: expected {0}, got {1}")] + ConflictInitResponseId(RequestId, RequestId), + + #[error("connection closed: {0}")] + ConnectionClosed(String), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +/// Helper function to get the next message from the stream +async fn expect_next_message( + stream: &mut S, + context: &str, +) -> Result +where + S: Stream + Unpin, +{ + stream + .next() + .await + .ok_or_else(|| ClientError::ConnectionClosed(context.to_string())) + .map_err(|e| ClientError::Io(std::io::Error::new(std::io::ErrorKind::Other, e))) +} + +/// Helper function to expect a response from the stream +async fn expect_response( + stream: &mut S, + context: &str, +) -> Result<(ServerResult, RequestId), ClientError> +where + S: Stream + Unpin, +{ + let msg = expect_next_message(stream, context).await?; + + match msg { + ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)), + _ => Err(ClientError::ExpectedInitResponse(Some(msg))), + } +} + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct RoleClient; @@ -74,6 +127,15 @@ where let mut sink = Box::pin(sink); let mut stream = Box::pin(stream); let id_provider = >::default(); + + // Convert ClientError to std::io::Error, then to E + let handle_client_error = |e: ClientError| -> E { + match e { + ClientError::Io(io_err) => io_err.into(), + other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(), + } + }; + // service let id = id_provider.next_request_id(); let init_request = InitializeRequest { @@ -85,34 +147,24 @@ where .into_json_rpc_message(), ) .await?; - let (response, response_id) = stream - .next() + + let (response, response_id) = expect_response(&mut stream, "initialize response") .await - .ok_or(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "expect initialize response", - ))? - .into_message() - .into_result() - .ok_or(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expect initialize result", - ))?; + .map_err(handle_client_error)?; + if id != response_id { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "conflict initialize response id", - ) - .into()); + return Err(handle_client_error(ClientError::ConflictInitResponseId( + id, + response_id, + ))); } - let response = response.map_err(std::io::Error::other)?; + let ServerResult::InitializeResult(initialize_result) = response else { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expect initialize result", - ) - .into()); + return Err(handle_client_error(ClientError::ExpectedInitResult(Some( + response, + )))); }; + // send notification let notification = ClientMessage::Notification(ClientNotification::InitializedNotification( InitializedNotification {