1- use futures:: { SinkExt , StreamExt } ;
1+ use futures:: { SinkExt , Stream , StreamExt } ;
2+ use thiserror:: Error ;
23
34use super :: * ;
45use crate :: model:: {
56 CallToolRequest , CallToolRequestParam , CallToolResult , CancelledNotification ,
67 CancelledNotificationParam , ClientInfo , ClientMessage , ClientNotification , ClientRequest ,
78 ClientResult , CompleteRequest , CompleteRequestParam , CompleteResult , GetPromptRequest ,
89 GetPromptRequestParam , GetPromptResult , InitializeRequest , InitializedNotification ,
9- ListPromptsRequest , ListPromptsResult , ListResourceTemplatesRequest ,
10+ JsonRpcResponse , ListPromptsRequest , ListPromptsResult , ListResourceTemplatesRequest ,
1011 ListResourceTemplatesResult , ListResourcesRequest , ListResourcesResult , ListToolsRequest ,
1112 ListToolsResult , PaginatedRequestParam , PaginatedRequestParamInner , ProgressNotification ,
1213 ProgressNotificationParam , ReadResourceRequest , ReadResourceRequestParam , ReadResourceResult ,
13- RootsListChangedNotification , ServerInfo , ServerNotification , ServerRequest , ServerResult ,
14- SetLevelRequest , SetLevelRequestParam , SubscribeRequest , SubscribeRequestParam ,
15- UnsubscribeRequest , UnsubscribeRequestParam ,
14+ RequestId , RootsListChangedNotification , ServerInfo , ServerJsonRpcMessage , ServerNotification ,
15+ ServerRequest , ServerResult , SetLevelRequest , SetLevelRequestParam , SubscribeRequest ,
16+ SubscribeRequestParam , UnsubscribeRequest , UnsubscribeRequestParam ,
1617} ;
1718
19+ /// It represents the error that may occur when serving the client.
20+ ///
21+ /// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
22+ #[ derive( Error , Debug ) ]
23+ pub enum ClientError {
24+ #[ error( "expect initialized response, but received: {0:?}" ) ]
25+ ExpectedInitResponse ( Option < ServerJsonRpcMessage > ) ,
26+
27+ #[ error( "expect initialized result, but received: {0:?}" ) ]
28+ ExpectedInitResult ( Option < ServerResult > ) ,
29+
30+ #[ error( "conflict initialized response id: expected {0}, got {1}" ) ]
31+ ConflictInitResponseId ( RequestId , RequestId ) ,
32+
33+ #[ error( "connection closed: {0}" ) ]
34+ ConnectionClosed ( String ) ,
35+
36+ #[ error( "IO error: {0}" ) ]
37+ Io ( #[ from] std:: io:: Error ) ,
38+ }
39+
40+ /// Helper function to get the next message from the stream
41+ async fn expect_next_message < S > (
42+ stream : & mut S ,
43+ context : & str ,
44+ ) -> Result < ServerJsonRpcMessage , ClientError >
45+ where
46+ S : Stream < Item = ServerJsonRpcMessage > + Unpin ,
47+ {
48+ stream
49+ . next ( )
50+ . await
51+ . ok_or_else ( || ClientError :: ConnectionClosed ( context. to_string ( ) ) )
52+ . map_err ( |e| ClientError :: Io ( std:: io:: Error :: new ( std:: io:: ErrorKind :: Other , e) ) )
53+ }
54+
55+ /// Helper function to expect a response from the stream
56+ async fn expect_response < S > (
57+ stream : & mut S ,
58+ context : & str ,
59+ ) -> Result < ( ServerResult , RequestId ) , ClientError >
60+ where
61+ S : Stream < Item = ServerJsonRpcMessage > + Unpin ,
62+ {
63+ let msg = expect_next_message ( stream, context) . await ?;
64+
65+ match msg {
66+ ServerJsonRpcMessage :: Response ( JsonRpcResponse { id, result, .. } ) => Ok ( ( result, id) ) ,
67+ _ => Err ( ClientError :: ExpectedInitResponse ( Some ( msg) ) ) ,
68+ }
69+ }
70+
1871#[ derive( Debug , Clone , Copy , Default , PartialEq , Eq ) ]
1972pub struct RoleClient ;
2073
@@ -74,6 +127,15 @@ where
74127 let mut sink = Box :: pin ( sink) ;
75128 let mut stream = Box :: pin ( stream) ;
76129 let id_provider = <Arc < AtomicU32RequestIdProvider > >:: default ( ) ;
130+
131+ // Convert ClientError to std::io::Error, then to E
132+ let handle_client_error = |e : ClientError | -> E {
133+ match e {
134+ ClientError :: Io ( io_err) => io_err. into ( ) ,
135+ other => std:: io:: Error :: new ( std:: io:: ErrorKind :: Other , format ! ( "{}" , other) ) . into ( ) ,
136+ }
137+ } ;
138+
77139 // service
78140 let id = id_provider. next_request_id ( ) ;
79141 let init_request = InitializeRequest {
@@ -85,34 +147,24 @@ where
85147 . into_json_rpc_message ( ) ,
86148 )
87149 . await ?;
88- let ( response , response_id ) = stream
89- . next ( )
150+
151+ let ( response , response_id ) = expect_response ( & mut stream , "initialize response" )
90152 . await
91- . ok_or ( std:: io:: Error :: new (
92- std:: io:: ErrorKind :: UnexpectedEof ,
93- "expect initialize response" ,
94- ) ) ?
95- . into_message ( )
96- . into_result ( )
97- . ok_or ( std:: io:: Error :: new (
98- std:: io:: ErrorKind :: InvalidData ,
99- "expect initialize result" ,
100- ) ) ?;
153+ . map_err ( handle_client_error) ?;
154+
101155 if id != response_id {
102- return Err ( std:: io:: Error :: new (
103- std:: io:: ErrorKind :: InvalidData ,
104- "conflict initialize response id" ,
105- )
106- . into ( ) ) ;
156+ return Err ( handle_client_error ( ClientError :: ConflictInitResponseId (
157+ id,
158+ response_id,
159+ ) ) ) ;
107160 }
108- let response = response . map_err ( std :: io :: Error :: other ) ? ;
161+
109162 let ServerResult :: InitializeResult ( initialize_result) = response else {
110- return Err ( std:: io:: Error :: new (
111- std:: io:: ErrorKind :: InvalidData ,
112- "expect initialize result" ,
113- )
114- . into ( ) ) ;
163+ return Err ( handle_client_error ( ClientError :: ExpectedInitResult ( Some (
164+ response,
165+ ) ) ) ) ;
115166 } ;
167+
116168 // send notification
117169 let notification = ClientMessage :: Notification ( ClientNotification :: InitializedNotification (
118170 InitializedNotification {
0 commit comments