@@ -5,7 +5,7 @@ use super::*;
55use crate :: model:: {
66 CancelledNotification , CancelledNotificationParam , ClientInfo , ClientJsonRpcMessage ,
77 ClientNotification , ClientRequest , ClientResult , CreateMessageRequest ,
8- CreateMessageRequestParam , CreateMessageResult , ListRootsRequest , ListRootsResult ,
8+ CreateMessageRequestParam , CreateMessageResult , ErrorData , ListRootsRequest , ListRootsResult ,
99 LoggingMessageNotification , LoggingMessageNotificationParam , ProgressNotification ,
1010 ProgressNotificationParam , PromptListChangedNotification , ResourceListChangedNotification ,
1111 ResourceUpdatedNotification , ResourceUpdatedNotificationParam , ServerInfo , ServerNotification ,
@@ -41,6 +41,12 @@ pub enum ServerError {
4141 #[ error( "connection closed: {0}" ) ]
4242 ConnectionClosed ( String ) ,
4343
44+ #[ error( "unexpected initialize result: {0:?}" ) ]
45+ UnexpectedInitializeResponse ( ServerResult ) ,
46+
47+ #[ error( "initialize failed: {0}" ) ]
48+ InitializeFailed ( ErrorData ) ,
49+
4450 #[ error( "IO error: {0}" ) ]
4551 Io ( #[ from] std:: io:: Error ) ,
4652}
@@ -144,14 +150,34 @@ where
144150 . await
145151 . map_err ( handle_server_error) ?;
146152
147- let ClientRequest :: InitializeRequest ( peer_info) = request else {
153+ let ClientRequest :: InitializeRequest ( peer_info) = & request else {
148154 return Err ( handle_server_error ( ServerError :: ExpectedInitRequest ( Some (
149155 ClientJsonRpcMessage :: request ( request, id) ,
150156 ) ) ) ) ;
151157 } ;
152-
158+ let ( peer, peer_rx) = Peer :: new ( id_provider, peer_info. params . clone ( ) ) ;
159+ let context = RequestContext {
160+ ct : ct. child_token ( ) ,
161+ id : id. clone ( ) ,
162+ meta : request. get_meta ( ) . clone ( ) ,
163+ extensions : request. extensions ( ) . clone ( ) ,
164+ peer : peer. clone ( ) ,
165+ } ;
153166 // Send initialize response
154- let mut init_response = service. get_info ( ) ;
167+ let init_response = service. handle_request ( request. clone ( ) , context) . await ;
168+ let mut init_response = match init_response {
169+ Ok ( ServerResult :: InitializeResult ( init_response) ) => init_response,
170+ Ok ( result) => {
171+ return Err ( handle_server_error (
172+ ServerError :: UnexpectedInitializeResponse ( result) ,
173+ ) ) ;
174+ }
175+ Err ( e) => {
176+ sink. send ( ServerJsonRpcMessage :: error ( e. clone ( ) , id) )
177+ . await ?;
178+ return Err ( handle_server_error ( ServerError :: InitializeFailed ( e) ) ) ;
179+ }
180+ } ;
155181 let protocol_version = match peer_info
156182 . params
157183 . protocol_version
@@ -174,15 +200,14 @@ where
174200 let notification = expect_notification ( & mut stream, "initialize notification" )
175201 . await
176202 . map_err ( handle_server_error) ?;
177-
178203 let ClientNotification :: InitializedNotification ( _) = notification else {
179204 return Err ( handle_server_error ( ServerError :: ExpectedInitNotification (
180205 Some ( ClientJsonRpcMessage :: notification ( notification) ) ,
181206 ) ) ) ;
182207 } ;
183-
208+ let _ = service . handle_notification ( notification ) . await ;
184209 // Continue processing service
185- serve_inner ( service, ( sink, stream) , peer_info . params , id_provider , ct) . await
210+ serve_inner ( service, ( sink, stream) , peer , peer_rx , ct) . await
186211}
187212
188213macro_rules! method {
0 commit comments