Skip to content

Commit 1dd51ee

Browse files
committed
fix(client): add error enum while deal server info
1. wrap the error type for more standardized 2. add more information in error for debug trace 3. wrap helper func for more user-friendly code Signed-off-by: jokemanfire <[email protected]>
1 parent 588a013 commit 1dd51ee

File tree

1 file changed

+81
-28
lines changed

1 file changed

+81
-28
lines changed

crates/rmcp/src/service/client.rs

Lines changed: 81 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,74 @@
1-
use futures::{SinkExt, StreamExt};
1+
use futures::{SinkExt, Stream, StreamExt};
2+
use thiserror::Error;
23

34
use super::*;
45
use crate::model::{
56
CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification,
67
CancelledNotificationParam, ClientInfo, ClientMessage, ClientNotification, ClientRequest,
78
ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, GetPromptRequest,
89
GetPromptRequestParam, GetPromptResult, InitializeRequest, InitializedNotification,
9-
ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
10-
ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
11-
ListToolsResult, PaginatedRequestParam, PaginatedRequestParamInner, ProgressNotification,
12-
ProgressNotificationParam, ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult,
10+
JsonRpcMessage, JsonRpcResponse, ListPromptsRequest, ListPromptsResult,
11+
ListResourceTemplatesRequest, ListResourceTemplatesResult, ListResourcesRequest,
12+
ListResourcesResult, ListToolsRequest, ListToolsResult, PaginatedRequestParam,
13+
PaginatedRequestParamInner, ProgressNotification, ProgressNotificationParam,
14+
ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult, RequestId,
1315
RootsListChangedNotification, ServerInfo, ServerNotification, ServerRequest, ServerResult,
1416
SetLevelRequest, SetLevelRequestParam, SubscribeRequest, SubscribeRequestParam,
1517
UnsubscribeRequest, UnsubscribeRequestParam,
1618
};
1719

20+
/// It represents the error that may occur when serving the client.
21+
///
22+
/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
23+
#[derive(Error, Debug)]
24+
pub enum ClientError {
25+
#[error("expect initialized response, but received: {0:?}")]
26+
ExpectedInitResponse(Option<JsonRpcMessage<ServerRequest, ServerResult, ServerNotification>>),
27+
28+
#[error("expect initialized result, but received: {0:?}")]
29+
ExpectedInitResult(Option<ServerResult>),
30+
31+
#[error("conflict initialized response id: expected {0}, got {1}")]
32+
ConflictInitResponseId(RequestId, RequestId),
33+
34+
#[error("connection closed: {0}")]
35+
ConnectionClosed(String),
36+
37+
#[error("IO error: {0}")]
38+
Io(#[from] std::io::Error),
39+
}
40+
41+
/// Helper function to get the next message from the stream
42+
async fn expect_next_message<S>(
43+
stream: &mut S,
44+
context: &str,
45+
) -> Result<JsonRpcMessage<ServerRequest, ServerResult, ServerNotification>, ClientError>
46+
where
47+
S: Stream<Item = JsonRpcMessage<ServerRequest, ServerResult, ServerNotification>> + Unpin,
48+
{
49+
stream
50+
.next()
51+
.await
52+
.ok_or_else(|| ClientError::ConnectionClosed(context.to_string()))
53+
.map_err(|e| ClientError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))
54+
}
55+
56+
/// Helper function to expect a response from the stream
57+
async fn expect_response<S>(
58+
stream: &mut S,
59+
context: &str,
60+
) -> Result<(ServerResult, RequestId), ClientError>
61+
where
62+
S: Stream<Item = JsonRpcMessage<ServerRequest, ServerResult, ServerNotification>> + Unpin,
63+
{
64+
let msg = expect_next_message(stream, context).await?;
65+
66+
match msg {
67+
JsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)),
68+
_ => Err(ClientError::ExpectedInitResponse(Some(msg))),
69+
}
70+
}
71+
1872
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
1973
pub struct RoleClient;
2074

@@ -74,6 +128,15 @@ where
74128
let mut sink = Box::pin(sink);
75129
let mut stream = Box::pin(stream);
76130
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
131+
132+
// Convert ClientError to std::io::Error, then to E
133+
let handle_client_error = |e: ClientError| -> E {
134+
match e {
135+
ClientError::Io(io_err) => io_err.into(),
136+
other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(),
137+
}
138+
};
139+
77140
// service
78141
let id = id_provider.next_request_id();
79142
let init_request = InitializeRequest {
@@ -85,34 +148,24 @@ where
85148
.into_json_rpc_message(),
86149
)
87150
.await?;
88-
let (response, response_id) = stream
89-
.next()
151+
152+
let (response, response_id) = expect_response(&mut stream, "initialize response")
90153
.await
91-
.ok_or(std::io::Error::new(
92-
std::io::ErrorKind::UnexpectedEof,
93-
"expect initialize response",
94-
))?
95-
.into_message()
96-
.into_result()
97-
.ok_or(std::io::Error::new(
98-
std::io::ErrorKind::InvalidData,
99-
"expect initialize result",
100-
))?;
154+
.map_err(handle_client_error)?;
155+
101156
if id != response_id {
102-
return Err(std::io::Error::new(
103-
std::io::ErrorKind::InvalidData,
104-
"conflict initialize response id",
105-
)
106-
.into());
157+
return Err(handle_client_error(ClientError::ConflictInitResponseId(
158+
id,
159+
response_id,
160+
)));
107161
}
108-
let response = response.map_err(std::io::Error::other)?;
162+
109163
let ServerResult::InitializeResult(initialize_result) = response else {
110-
return Err(std::io::Error::new(
111-
std::io::ErrorKind::InvalidData,
112-
"expect initialize result",
113-
)
114-
.into());
164+
return Err(handle_client_error(ClientError::ExpectedInitResult(Some(
165+
response,
166+
))));
115167
};
168+
116169
// send notification
117170
let notification = ClientMessage::Notification(ClientNotification::InitializedNotification(
118171
InitializedNotification {

0 commit comments

Comments
 (0)