From 8e0d47bfad38d9d30602ec0772357c296f26117e Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Tue, 1 Apr 2025 10:46:55 +0800 Subject: [PATCH] fix(client): add error enum while deal client info 1. wrap the error type for more standardized 2. add more information in error for debug trace 3. wrap helper func for more user-friendly code Signed-off-by: jokemanfire --- crates/rmcp/src/service/client.rs | 108 ++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 28 deletions(-) diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 134b2ffb..b2e30e0c 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -1,4 +1,5 @@ -use futures::{SinkExt, StreamExt}; +use futures::{SinkExt, Stream, StreamExt}; +use thiserror::Error; use super::*; use crate::model::{ @@ -6,15 +7,67 @@ use crate::model::{ CancelledNotificationParam, ClientInfo, ClientMessage, ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, GetPromptRequest, GetPromptRequestParam, GetPromptResult, InitializeRequest, InitializedNotification, - ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest, + JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest, ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest, ListToolsResult, PaginatedRequestParam, PaginatedRequestParamInner, ProgressNotification, ProgressNotificationParam, ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult, - RootsListChangedNotification, ServerInfo, ServerNotification, ServerRequest, ServerResult, - SetLevelRequest, SetLevelRequestParam, SubscribeRequest, SubscribeRequestParam, - UnsubscribeRequest, UnsubscribeRequestParam, + RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification, + ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam, SubscribeRequest, + SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam, }; +/// It represents the error that may occur when serving the client. +/// +/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result, ClientError>` +#[derive(Error, Debug)] +pub enum ClientError { + #[error("expect initialized response, but received: {0:?}")] + ExpectedInitResponse(Option), + + #[error("expect initialized result, but received: {0:?}")] + ExpectedInitResult(Option), + + #[error("conflict initialized response id: expected {0}, got {1}")] + ConflictInitResponseId(RequestId, RequestId), + + #[error("connection closed: {0}")] + ConnectionClosed(String), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +/// Helper function to get the next message from the stream +async fn expect_next_message( + stream: &mut S, + context: &str, +) -> Result +where + S: Stream + Unpin, +{ + stream + .next() + .await + .ok_or_else(|| ClientError::ConnectionClosed(context.to_string())) + .map_err(|e| ClientError::Io(std::io::Error::new(std::io::ErrorKind::Other, e))) +} + +/// Helper function to expect a response from the stream +async fn expect_response( + stream: &mut S, + context: &str, +) -> Result<(ServerResult, RequestId), ClientError> +where + S: Stream + Unpin, +{ + let msg = expect_next_message(stream, context).await?; + + match msg { + ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)), + _ => Err(ClientError::ExpectedInitResponse(Some(msg))), + } +} + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct RoleClient; @@ -74,6 +127,15 @@ where let mut sink = Box::pin(sink); let mut stream = Box::pin(stream); let id_provider = >::default(); + + // Convert ClientError to std::io::Error, then to E + let handle_client_error = |e: ClientError| -> E { + match e { + ClientError::Io(io_err) => io_err.into(), + other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(), + } + }; + // service let id = id_provider.next_request_id(); let init_request = InitializeRequest { @@ -85,34 +147,24 @@ where .into_json_rpc_message(), ) .await?; - let (response, response_id) = stream - .next() + + let (response, response_id) = expect_response(&mut stream, "initialize response") .await - .ok_or(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "expect initialize response", - ))? - .into_message() - .into_result() - .ok_or(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expect initialize result", - ))?; + .map_err(handle_client_error)?; + if id != response_id { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "conflict initialize response id", - ) - .into()); + return Err(handle_client_error(ClientError::ConflictInitResponseId( + id, + response_id, + ))); } - let response = response.map_err(std::io::Error::other)?; + let ServerResult::InitializeResult(initialize_result) = response else { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expect initialize result", - ) - .into()); + return Err(handle_client_error(ClientError::ExpectedInitResult(Some( + response, + )))); }; + // send notification let notification = ClientMessage::Notification(ClientNotification::InitializedNotification( InitializedNotification {