1- use futures:: { SinkExt , StreamExt } ;
1+ use futures:: { SinkExt , Stream , StreamExt } ;
2+ use thiserror:: Error ;
23
34use super :: * ;
45use crate :: model:: {
56 CallToolRequest , CallToolRequestParam , CallToolResult , CancelledNotification ,
67 CancelledNotificationParam , ClientInfo , ClientMessage , ClientNotification , ClientRequest ,
78 ClientResult , CompleteRequest , CompleteRequestParam , CompleteResult , GetPromptRequest ,
89 GetPromptRequestParam , GetPromptResult , InitializeRequest , InitializedNotification ,
9- ListPromptsRequest , ListPromptsResult , ListResourceTemplatesRequest ,
10- ListResourceTemplatesResult , ListResourcesRequest , ListResourcesResult , ListToolsRequest ,
11- ListToolsResult , PaginatedRequestParam , PaginatedRequestParamInner , ProgressNotification ,
12- ProgressNotificationParam , ReadResourceRequest , ReadResourceRequestParam , ReadResourceResult ,
10+ JsonRpcMessage , JsonRpcResponse , ListPromptsRequest , ListPromptsResult ,
11+ ListResourceTemplatesRequest , ListResourceTemplatesResult , ListResourcesRequest ,
12+ ListResourcesResult , ListToolsRequest , ListToolsResult , PaginatedRequestParam ,
13+ PaginatedRequestParamInner , ProgressNotification , ProgressNotificationParam ,
14+ ReadResourceRequest , ReadResourceRequestParam , ReadResourceResult , RequestId ,
1315 RootsListChangedNotification , ServerInfo , ServerNotification , ServerRequest , ServerResult ,
1416 SetLevelRequest , SetLevelRequestParam , SubscribeRequest , SubscribeRequestParam ,
1517 UnsubscribeRequest , UnsubscribeRequestParam ,
1618} ;
1719
20+ /// It represents the error that may occur when serving the client.
21+ ///
22+ /// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
23+ #[ derive( Error , Debug ) ]
24+ pub enum ClientError {
25+ #[ error( "expect initialized response, but received: {0:?}" ) ]
26+ ExpectedInitResponse ( Option < JsonRpcMessage < ServerRequest , ServerResult , ServerNotification > > ) ,
27+
28+ #[ error( "expect initialized result, but received: {0:?}" ) ]
29+ ExpectedInitResult ( Option < ServerResult > ) ,
30+
31+ #[ error( "conflict initialized response id: expected {0}, got {1}" ) ]
32+ ConflictInitResponseId ( RequestId , RequestId ) ,
33+
34+ #[ error( "connection closed: {0}" ) ]
35+ ConnectionClosed ( String ) ,
36+
37+ #[ error( "IO error: {0}" ) ]
38+ Io ( #[ from] std:: io:: Error ) ,
39+ }
40+
41+ /// Helper function to get the next message from the stream
42+ async fn expect_next_message < S > (
43+ stream : & mut S ,
44+ context : & str ,
45+ ) -> Result < JsonRpcMessage < ServerRequest , ServerResult , ServerNotification > , ClientError >
46+ where
47+ S : Stream < Item = JsonRpcMessage < ServerRequest , ServerResult , ServerNotification > > + Unpin ,
48+ {
49+ stream
50+ . next ( )
51+ . await
52+ . ok_or_else ( || ClientError :: ConnectionClosed ( context. to_string ( ) ) )
53+ . map_err ( |e| ClientError :: Io ( std:: io:: Error :: new ( std:: io:: ErrorKind :: Other , e) ) )
54+ }
55+
56+ /// Helper function to expect a response from the stream
57+ async fn expect_response < S > (
58+ stream : & mut S ,
59+ context : & str ,
60+ ) -> Result < ( ServerResult , RequestId ) , ClientError >
61+ where
62+ S : Stream < Item = JsonRpcMessage < ServerRequest , ServerResult , ServerNotification > > + Unpin ,
63+ {
64+ let msg = expect_next_message ( stream, context) . await ?;
65+
66+ match msg {
67+ JsonRpcMessage :: Response ( JsonRpcResponse { id, result, .. } ) => Ok ( ( result, id) ) ,
68+ _ => Err ( ClientError :: ExpectedInitResponse ( Some ( msg) ) ) ,
69+ }
70+ }
71+
1872#[ derive( Debug , Clone , Copy , Default , PartialEq , Eq ) ]
1973pub struct RoleClient ;
2074
@@ -74,6 +128,15 @@ where
74128 let mut sink = Box :: pin ( sink) ;
75129 let mut stream = Box :: pin ( stream) ;
76130 let id_provider = <Arc < AtomicU32RequestIdProvider > >:: default ( ) ;
131+
132+ // Convert ClientError to std::io::Error, then to E
133+ let handle_client_error = |e : ClientError | -> E {
134+ match e {
135+ ClientError :: Io ( io_err) => io_err. into ( ) ,
136+ other => std:: io:: Error :: new ( std:: io:: ErrorKind :: Other , format ! ( "{}" , other) ) . into ( ) ,
137+ }
138+ } ;
139+
77140 // service
78141 let id = id_provider. next_request_id ( ) ;
79142 let init_request = InitializeRequest {
@@ -85,34 +148,24 @@ where
85148 . into_json_rpc_message ( ) ,
86149 )
87150 . await ?;
88- let ( response , response_id ) = stream
89- . next ( )
151+
152+ let ( response , response_id ) = expect_response ( & mut stream , "initialize response" )
90153 . await
91- . ok_or ( std:: io:: Error :: new (
92- std:: io:: ErrorKind :: UnexpectedEof ,
93- "expect initialize response" ,
94- ) ) ?
95- . into_message ( )
96- . into_result ( )
97- . ok_or ( std:: io:: Error :: new (
98- std:: io:: ErrorKind :: InvalidData ,
99- "expect initialize result" ,
100- ) ) ?;
154+ . map_err ( handle_client_error) ?;
155+
101156 if id != response_id {
102- return Err ( std:: io:: Error :: new (
103- std:: io:: ErrorKind :: InvalidData ,
104- "conflict initialize response id" ,
105- )
106- . into ( ) ) ;
157+ return Err ( handle_client_error ( ClientError :: ConflictInitResponseId (
158+ id,
159+ response_id,
160+ ) ) ) ;
107161 }
108- let response = response . map_err ( std :: io :: Error :: other ) ? ;
162+
109163 let ServerResult :: InitializeResult ( initialize_result) = response else {
110- return Err ( std:: io:: Error :: new (
111- std:: io:: ErrorKind :: InvalidData ,
112- "expect initialize result" ,
113- )
114- . into ( ) ) ;
164+ return Err ( handle_client_error ( ClientError :: ExpectedInitResult ( Some (
165+ response,
166+ ) ) ) ) ;
115167 } ;
168+
116169 // send notification
117170 let notification = ClientMessage :: Notification ( ClientNotification :: InitializedNotification (
118171 InitializedNotification {
0 commit comments