11use futures:: { SinkExt , StreamExt } ;
2+ use thiserror:: Error ;
23
34use super :: * ;
45use crate :: model:: {
5- CancelledNotification , CancelledNotificationParam , ClientInfo , ClientNotification ,
6- ClientRequest , ClientResult , CreateMessageRequest , CreateMessageRequestParam ,
7- CreateMessageResult , ListRootsRequest , ListRootsResult , LoggingMessageNotification ,
8- LoggingMessageNotificationParam , ProgressNotification , ProgressNotificationParam ,
9- PromptListChangedNotification , ResourceListChangedNotification , ResourceUpdatedNotification ,
10- ResourceUpdatedNotificationParam , ServerInfo , ServerMessage , ServerNotification , ServerRequest ,
11- ServerResult , ToolListChangedNotification ,
6+ CancelledNotification , CancelledNotificationParam , ClientInfo , ClientJsonRpcMessage ,
7+ ClientMessage , ClientNotification , ClientRequest , ClientResult , CreateMessageRequest ,
8+ CreateMessageRequestParam , CreateMessageResult , ListRootsRequest , ListRootsResult ,
9+ LoggingMessageNotification , LoggingMessageNotificationParam , ProgressNotification ,
10+ ProgressNotificationParam , PromptListChangedNotification , ResourceListChangedNotification ,
11+ ResourceUpdatedNotification , ResourceUpdatedNotificationParam , ServerInfo , ServerMessage ,
12+ ServerNotification , ServerRequest , ServerResult , ToolListChangedNotification ,
1213} ;
1314
1415#[ derive( Debug , Clone , Copy , Default , PartialEq , Eq ) ]
@@ -26,6 +27,24 @@ impl ServiceRole for RoleServer {
2627 const IS_CLIENT : bool = false ;
2728}
2829
30+ /// It represents the error that may occur when serving the server.
31+ ///
32+ /// if you want to handle the error, you can use `serve_server_with_ct` or `serve_server` with `Result<RunningService<RoleServer, S>, ServerError>`
33+ #[ derive( Error , Debug ) ]
34+ pub enum ServerError {
35+ #[ error( "expect initialized request, but received: {0:?}" ) ]
36+ ExpectedInitRequest ( Option < ClientMessage > ) ,
37+
38+ #[ error( "expect initialized notification, but received: {0:?}" ) ]
39+ ExpectedInitNotification ( Option < ClientMessage > ) ,
40+
41+ #[ error( "connection closed: {0}" ) ]
42+ ConnectionClosed ( String ) ,
43+
44+ #[ error( "IO error: {0}" ) ]
45+ Io ( #[ from] std:: io:: Error ) ,
46+ }
47+
2948pub type ClientSink = Peer < RoleServer > ;
3049
3150impl < S : Service < RoleServer > > ServiceExt < RoleServer > for S {
5574 serve_server_with_ct ( service, transport, CancellationToken :: new ( ) ) . await
5675}
5776
77+ /// Helper function to get the next message from the stream
78+ async fn expect_next_message < S > ( stream : & mut S , context : & str ) -> Result < ClientMessage , ServerError >
79+ where
80+ S : StreamExt < Item = ClientJsonRpcMessage > + Unpin ,
81+ {
82+ Ok ( stream
83+ . next ( )
84+ . await
85+ . ok_or_else ( || ServerError :: ConnectionClosed ( context. to_string ( ) ) ) ?
86+ . into_message ( ) )
87+ }
88+
89+ /// Helper function to expect a request from the stream
90+ async fn expect_request < S > (
91+ stream : & mut S ,
92+ context : & str ,
93+ ) -> Result < ( ClientRequest , RequestId ) , ServerError >
94+ where
95+ S : StreamExt < Item = ClientJsonRpcMessage > + Unpin ,
96+ {
97+ let msg = expect_next_message ( stream, context) . await ?;
98+ let msg_clone = msg. clone ( ) ;
99+ msg. into_request ( )
100+ . ok_or ( ServerError :: ExpectedInitRequest ( Some ( msg_clone) ) )
101+ }
102+
103+ /// Helper function to expect a notification from the stream
104+ async fn expect_notification < S > (
105+ stream : & mut S ,
106+ context : & str ,
107+ ) -> Result < ClientNotification , ServerError >
108+ where
109+ S : StreamExt < Item = ClientJsonRpcMessage > + Unpin ,
110+ {
111+ let msg = expect_next_message ( stream, context) . await ?;
112+ let msg_clone = msg. clone ( ) ;
113+ msg. into_notification ( )
114+ . ok_or ( ServerError :: ExpectedInitNotification ( Some ( msg_clone) ) )
115+ }
116+
58117pub async fn serve_server_with_ct < S , T , E , A > (
59118 service : S ,
60119 transport : T ,
@@ -70,54 +129,45 @@ where
70129 let mut stream = Box :: pin ( stream) ;
71130 let id_provider = <Arc < AtomicU32RequestIdProvider > >:: default ( ) ;
72131
73- // service
74- let ( request, id) = stream
75- . next ( )
132+ // Convert ServerError to std::io::Error, then to E
133+ let handle_server_error = |e : ServerError | -> E {
134+ match e {
135+ ServerError :: Io ( io_err) => io_err. into ( ) ,
136+ other => std:: io:: Error :: new ( std:: io:: ErrorKind :: Other , format ! ( "{}" , other) ) . into ( ) ,
137+ }
138+ } ;
139+
140+ // Get initialize request
141+ let ( request, id) = expect_request ( & mut stream, "initialized request" )
76142 . await
77- . ok_or ( std:: io:: Error :: new (
78- std:: io:: ErrorKind :: UnexpectedEof ,
79- "expect initialize request" ,
80- ) ) ?
81- . into_message ( )
82- . into_request ( )
83- . ok_or ( std:: io:: Error :: new (
84- std:: io:: ErrorKind :: InvalidData ,
85- "expect initialize request" ,
86- ) ) ?;
143+ . map_err ( handle_server_error) ?;
144+
87145 let ClientRequest :: InitializeRequest ( peer_info) = request else {
88- return Err ( std:: io:: Error :: new (
89- std:: io:: ErrorKind :: InvalidData ,
90- "expect initialize request" ,
91- )
92- . into ( ) ) ;
146+ return Err ( handle_server_error ( ServerError :: ExpectedInitRequest ( Some (
147+ ClientMessage :: Request ( request, id) ,
148+ ) ) ) ) ;
93149 } ;
150+
151+ // Send initialize response
94152 let init_response = service. get_info ( ) ;
95153 sink. send (
96154 ServerMessage :: Response ( ServerResult :: InitializeResult ( init_response) , id)
97155 . into_json_rpc_message ( ) ,
98156 )
99157 . await ?;
100- // waiting for notification
101- let notification = stream
102- . next ( )
158+
159+ // Wait for initialize notification
160+ let notification = expect_notification ( & mut stream , "initialize notification" )
103161 . await
104- . ok_or ( std:: io:: Error :: new (
105- std:: io:: ErrorKind :: UnexpectedEof ,
106- "expect initialize notification" ,
107- ) ) ?
108- . into_message ( )
109- . into_notification ( )
110- . ok_or ( std:: io:: Error :: new (
111- std:: io:: ErrorKind :: InvalidData ,
112- "expect initialize notification" ,
113- ) ) ?;
162+ . map_err ( handle_server_error) ?;
163+
114164 let ClientNotification :: InitializedNotification ( _) = notification else {
115- return Err ( std:: io:: Error :: new (
116- std:: io:: ErrorKind :: InvalidData ,
117- "expect initialize notification" ,
118- )
119- . into ( ) ) ;
165+ return Err ( handle_server_error ( ServerError :: ExpectedInitNotification (
166+ Some ( ClientMessage :: Notification ( notification) ) ,
167+ ) ) ) ;
120168 } ;
169+
170+ // Continue processing service
121171 serve_inner ( service, ( sink, stream) , peer_info. params , id_provider, ct) . await
122172}
123173
0 commit comments