22use std:: { pin:: Pin , sync:: Arc } ;
33
44use futures:: { StreamExt , future:: BoxFuture } ;
5+ use http:: Uri ;
56use reqwest:: header:: HeaderValue ;
67use sse_stream:: Error as SseError ;
78use thiserror:: Error ;
@@ -32,6 +33,10 @@ pub enum SseTransportError<E: std::error::Error + Send + Sync + 'static> {
3233 #[ cfg_attr( docsrs, doc( cfg( feature = "auth" ) ) ) ]
3334 #[ error( "Auth error: {0}" ) ]
3435 Auth ( #[ from] crate :: transport:: auth:: AuthError ) ,
36+ #[ error( "Invalid uri: {0}" ) ]
37+ InvalidUri ( #[ from] http:: uri:: InvalidUri ) ,
38+ #[ error( "Invalid uri parts: {0}" ) ]
39+ InvalidUriParts ( #[ from] http:: uri:: InvalidUriParts ) ,
3540}
3641
3742impl From < reqwest:: Error > for SseTransportError < reqwest:: Error > {
@@ -44,21 +49,21 @@ pub trait SseClient: Clone + Send + Sync + 'static {
4449 type Error : std:: error:: Error + Send + Sync + ' static ;
4550 fn post_message (
4651 & self ,
47- uri : Arc < str > ,
52+ uri : Uri ,
4853 message : ClientJsonRpcMessage ,
4954 auth_token : Option < String > ,
5055 ) -> impl Future < Output = Result < ( ) , SseTransportError < Self :: Error > > > + Send + ' _ ;
5156 fn get_stream (
5257 & self ,
53- uri : Arc < str > ,
58+ uri : Uri ,
5459 last_event_id : Option < String > ,
5560 auth_token : Option < String > ,
5661 ) -> impl Future < Output = Result < BoxedSseResponse , SseTransportError < Self :: Error > > > + Send + ' _ ;
5762}
5863
5964struct SseClientReconnect < C > {
6065 pub client : C ,
61- pub uri : Arc < str > ,
66+ pub uri : Uri ,
6267}
6368
6469impl < C : SseClient > SseStreamReconnect for SseClientReconnect < C > {
@@ -75,7 +80,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
7580pub struct SseClientTransport < C : SseClient > {
7681 client : C ,
7782 config : SseClientConfig ,
78- post_uri : Arc < str > ,
83+ message_endpoint : Uri ,
7984 stream : Option < ServerMessageStream < C > > ,
8085}
8186
@@ -89,7 +94,7 @@ impl<C: SseClient> Transport<RoleClient> for SseClientTransport<C> {
8994 item : crate :: service:: TxJsonRpcMessage < RoleClient > ,
9095 ) -> impl Future < Output = Result < ( ) , Self :: Error > > + Send + ' static {
9196 let client = self . client . clone ( ) ;
92- let uri = self . post_uri . clone ( ) ;
97+ let uri = self . message_endpoint . clone ( ) ;
9398 async move { client. post_message ( uri, item, None ) . await }
9499 }
95100 async fn close ( & mut self ) -> Result < ( ) , Self :: Error > {
@@ -112,9 +117,11 @@ impl<C: SseClient> SseClientTransport<C> {
112117 client : C ,
113118 config : SseClientConfig ,
114119 ) -> Result < Self , SseTransportError < C :: Error > > {
115- let mut sse_stream = client. get_stream ( config. uri . clone ( ) , None , None ) . await ?;
116- let endpoint = if let Some ( endpoint) = config. use_endpoint . clone ( ) {
117- endpoint
120+ let sse_endpoint = config. sse_endpoint . as_ref ( ) . parse :: < http:: Uri > ( ) ?;
121+
122+ let mut sse_stream = client. get_stream ( sse_endpoint. clone ( ) , None , None ) . await ?;
123+ let message_endpoint = if let Some ( endpoint) = config. use_message_endpoint . clone ( ) {
124+ endpoint. parse :: < http:: Uri > ( ) ?
118125 } else {
119126 // wait the endpoint event
120127 loop {
@@ -125,46 +132,59 @@ impl<C: SseClient> SseClientTransport<C> {
125132 let Some ( "endpoint" ) = sse. event . as_deref ( ) else {
126133 continue ;
127134 } ;
128- break sse. data . unwrap_or_default ( ) ;
135+ let sse_endpoint = sse. data . unwrap_or_default ( ) ;
136+ break sse_endpoint. parse :: < http:: Uri > ( ) ?;
129137 }
130138 } ;
131- let post_uri: Arc < str > = format ! (
132- "{}/{}" ,
133- config. uri. trim_end_matches( "/" ) ,
134- endpoint. trim_start_matches( "/" )
135- )
136- . into ( ) ;
139+
140+ // sse: <authority><sse_pq> -> <authority><message_pq>
141+ let message_endpoint = {
142+ let mut sse_endpoint_parts = sse_endpoint. clone ( ) . into_parts ( ) ;
143+ sse_endpoint_parts. path_and_query = message_endpoint. into_parts ( ) . path_and_query ;
144+ Uri :: from_parts ( sse_endpoint_parts) ?
145+ } ;
137146 let stream = Box :: pin ( SseAutoReconnectStream :: new (
138147 sse_stream,
139148 SseClientReconnect {
140149 client : client. clone ( ) ,
141- uri : config . uri . clone ( ) ,
150+ uri : sse_endpoint . clone ( ) ,
142151 } ,
143152 config. retry_policy . clone ( ) ,
144153 ) ) ;
145154 Ok ( Self {
146155 client,
147156 config,
148- post_uri ,
157+ message_endpoint ,
149158 stream : Some ( stream) ,
150159 } )
151160 }
152161}
153162
154163#[ derive( Debug , Clone ) ]
155164pub struct SseClientConfig {
156- pub uri : Arc < str > ,
165+ /// client sse endpoint
166+ ///
167+ /// # How this client resolve the message endpoint
168+ /// if sse_endpoint has this format: `<schema><authority?><sse_pq>`,
169+ /// then the message endpoint will be `<schema><authority?><message_pq>`.
170+ ///
171+ /// For example, if you config the sse_endpoint as `http://example.com/some_path/sse`,
172+ /// and the server send the message endpoint event as `message?session_id=123`,
173+ /// then the message endpoint will be `http://example.com/message`.
174+ ///
175+ /// This follow the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/zh-CN/docs/Web/API/URL/URL)
176+ pub sse_endpoint : Arc < str > ,
157177 pub retry_policy : Arc < dyn SseRetryPolicy > ,
158178 /// if this is settled, the client will use this endpoint to send message and skip get the endpoint event
159- pub use_endpoint : Option < String > ,
179+ pub use_message_endpoint : Option < String > ,
160180}
161181
162182impl Default for SseClientConfig {
163183 fn default ( ) -> Self {
164184 Self {
165- uri : "" . into ( ) ,
185+ sse_endpoint : "" . into ( ) ,
166186 retry_policy : Arc :: new ( super :: common:: client_side_sse:: FixedInterval :: default ( ) ) ,
167- use_endpoint : None ,
187+ use_message_endpoint : None ,
168188 }
169189 }
170190}
0 commit comments