Skip to content

Commit 8e0d47b

Browse files
committed
fix(client): add error enum while deal client info
1. wrap the error type for more standardized 2. add more information in error for debug trace 3. wrap helper func for more user-friendly code Signed-off-by: jokemanfire <[email protected]>
1 parent 588a013 commit 8e0d47b

File tree

1 file changed

+80
-28
lines changed

1 file changed

+80
-28
lines changed

crates/rmcp/src/service/client.rs

Lines changed: 80 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,73 @@
1-
use futures::{SinkExt, StreamExt};
1+
use futures::{SinkExt, Stream, StreamExt};
2+
use thiserror::Error;
23

34
use super::*;
45
use crate::model::{
56
CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification,
67
CancelledNotificationParam, ClientInfo, ClientMessage, ClientNotification, ClientRequest,
78
ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, GetPromptRequest,
89
GetPromptRequestParam, GetPromptResult, InitializeRequest, InitializedNotification,
9-
ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
10+
JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
1011
ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
1112
ListToolsResult, PaginatedRequestParam, PaginatedRequestParamInner, ProgressNotification,
1213
ProgressNotificationParam, ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult,
13-
RootsListChangedNotification, ServerInfo, ServerNotification, ServerRequest, ServerResult,
14-
SetLevelRequest, SetLevelRequestParam, SubscribeRequest, SubscribeRequestParam,
15-
UnsubscribeRequest, UnsubscribeRequestParam,
14+
RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
15+
ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam, SubscribeRequest,
16+
SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam,
1617
};
1718

19+
/// It represents the error that may occur when serving the client.
20+
///
21+
/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
22+
#[derive(Error, Debug)]
23+
pub enum ClientError {
24+
#[error("expect initialized response, but received: {0:?}")]
25+
ExpectedInitResponse(Option<ServerJsonRpcMessage>),
26+
27+
#[error("expect initialized result, but received: {0:?}")]
28+
ExpectedInitResult(Option<ServerResult>),
29+
30+
#[error("conflict initialized response id: expected {0}, got {1}")]
31+
ConflictInitResponseId(RequestId, RequestId),
32+
33+
#[error("connection closed: {0}")]
34+
ConnectionClosed(String),
35+
36+
#[error("IO error: {0}")]
37+
Io(#[from] std::io::Error),
38+
}
39+
40+
/// Helper function to get the next message from the stream
41+
async fn expect_next_message<S>(
42+
stream: &mut S,
43+
context: &str,
44+
) -> Result<ServerJsonRpcMessage, ClientError>
45+
where
46+
S: Stream<Item = ServerJsonRpcMessage> + Unpin,
47+
{
48+
stream
49+
.next()
50+
.await
51+
.ok_or_else(|| ClientError::ConnectionClosed(context.to_string()))
52+
.map_err(|e| ClientError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))
53+
}
54+
55+
/// Helper function to expect a response from the stream
56+
async fn expect_response<S>(
57+
stream: &mut S,
58+
context: &str,
59+
) -> Result<(ServerResult, RequestId), ClientError>
60+
where
61+
S: Stream<Item = ServerJsonRpcMessage> + Unpin,
62+
{
63+
let msg = expect_next_message(stream, context).await?;
64+
65+
match msg {
66+
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)),
67+
_ => Err(ClientError::ExpectedInitResponse(Some(msg))),
68+
}
69+
}
70+
1871
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
1972
pub struct RoleClient;
2073

@@ -74,6 +127,15 @@ where
74127
let mut sink = Box::pin(sink);
75128
let mut stream = Box::pin(stream);
76129
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
130+
131+
// Convert ClientError to std::io::Error, then to E
132+
let handle_client_error = |e: ClientError| -> E {
133+
match e {
134+
ClientError::Io(io_err) => io_err.into(),
135+
other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(),
136+
}
137+
};
138+
77139
// service
78140
let id = id_provider.next_request_id();
79141
let init_request = InitializeRequest {
@@ -85,34 +147,24 @@ where
85147
.into_json_rpc_message(),
86148
)
87149
.await?;
88-
let (response, response_id) = stream
89-
.next()
150+
151+
let (response, response_id) = expect_response(&mut stream, "initialize response")
90152
.await
91-
.ok_or(std::io::Error::new(
92-
std::io::ErrorKind::UnexpectedEof,
93-
"expect initialize response",
94-
))?
95-
.into_message()
96-
.into_result()
97-
.ok_or(std::io::Error::new(
98-
std::io::ErrorKind::InvalidData,
99-
"expect initialize result",
100-
))?;
153+
.map_err(handle_client_error)?;
154+
101155
if id != response_id {
102-
return Err(std::io::Error::new(
103-
std::io::ErrorKind::InvalidData,
104-
"conflict initialize response id",
105-
)
106-
.into());
156+
return Err(handle_client_error(ClientError::ConflictInitResponseId(
157+
id,
158+
response_id,
159+
)));
107160
}
108-
let response = response.map_err(std::io::Error::other)?;
161+
109162
let ServerResult::InitializeResult(initialize_result) = response else {
110-
return Err(std::io::Error::new(
111-
std::io::ErrorKind::InvalidData,
112-
"expect initialize result",
113-
)
114-
.into());
163+
return Err(handle_client_error(ClientError::ExpectedInitResult(Some(
164+
response,
165+
))));
115166
};
167+
116168
// send notification
117169
let notification = ClientMessage::Notification(ClientNotification::InitializedNotification(
118170
InitializedNotification {

0 commit comments

Comments
 (0)