1- use std:: { collections:: HashMap , net:: SocketAddr , sync:: Arc } ;
1+ use std:: { collections:: HashMap , io , net:: SocketAddr , sync:: Arc , time :: Duration } ;
22
33use axum:: {
44 Json , Router ,
55 extract:: { Query , State } ,
66 http:: StatusCode ,
77 response:: {
88 Response ,
9- sse:: { Event , Sse } ,
9+ sse:: { Event , KeepAlive , Sse } ,
1010 } ,
1111 routing:: { get, post} ,
1212} ;
13- use futures:: { Sink , SinkExt , Stream , StreamExt } ;
14- use tokio:: io;
13+ use futures:: { Sink , SinkExt , Stream } ;
1514use tokio_stream:: wrappers:: ReceiverStream ;
1615use tokio_util:: sync:: { CancellationToken , PollSender } ;
1716use tracing:: Instrument ;
@@ -26,28 +25,33 @@ type TxStore =
2625 Arc < tokio:: sync:: RwLock < HashMap < SessionId , tokio:: sync:: mpsc:: Sender < ClientJsonRpcMessage > > > > ;
2726pub type TransportReceiver = ReceiverStream < RxJsonRpcMessage < RoleServer > > ;
2827
28+ const DEFAULT_AUTO_PING_INTERVAL : Duration = Duration :: from_secs ( 15 ) ;
29+
2930#[ derive( Clone ) ]
3031struct App {
3132 txs : TxStore ,
3233 transport_tx : tokio:: sync:: mpsc:: UnboundedSender < SseServerTransport > ,
3334 post_path : Arc < str > ,
35+ sse_ping_interval : Duration ,
3436}
3537
3638impl App {
3739 pub fn new (
3840 post_path : String ,
41+ sse_ping_interval : Duration ,
3942 ) -> (
4043 Self ,
4144 tokio:: sync:: mpsc:: UnboundedReceiver < SseServerTransport > ,
4245 ) {
43- let ( transport_tx, tranport_rx ) = tokio:: sync:: mpsc:: unbounded_channel ( ) ;
46+ let ( transport_tx, transport_rx ) = tokio:: sync:: mpsc:: unbounded_channel ( ) ;
4447 (
4548 Self {
4649 txs : Default :: default ( ) ,
4750 transport_tx,
4851 post_path : post_path. into ( ) ,
52+ sse_ping_interval,
4953 } ,
50- tranport_rx ,
54+ transport_rx ,
5155 )
5256 }
5357}
@@ -87,7 +91,7 @@ async fn sse_handler(
8791) -> Result < Sse < impl Stream < Item = Result < Event , io:: Error > > > , Response < String > > {
8892 let session = session_id ( ) ;
8993 tracing:: info!( %session, "sse connection" ) ;
90- use tokio_stream:: wrappers:: ReceiverStream ;
94+ use tokio_stream:: { StreamExt , wrappers:: ReceiverStream } ;
9195 use tokio_util:: sync:: PollSender ;
9296 let ( from_client_tx, from_client_rx) = tokio:: sync:: mpsc:: channel ( 64 ) ;
9397 let ( to_client_tx, to_client_rx) = tokio:: sync:: mpsc:: channel ( 64 ) ;
@@ -108,11 +112,12 @@ async fn sse_handler(
108112 if transport_send_result. is_err ( ) {
109113 tracing:: warn!( "send transport out error" ) ;
110114 let mut response =
111- Response :: new ( "fail to send out trasnport , it seems server is closed" . to_string ( ) ) ;
115+ Response :: new ( "fail to send out transport , it seems server is closed" . to_string ( ) ) ;
112116 * response. status_mut ( ) = StatusCode :: INTERNAL_SERVER_ERROR ;
113117 return Err ( response) ;
114118 }
115119 let post_path = app. post_path . as_ref ( ) ;
120+ let ping_interval = app. sse_ping_interval ;
116121 let stream = futures:: stream:: once ( futures:: future:: ok (
117122 Event :: default ( )
118123 . event ( "endpoint" )
@@ -124,7 +129,7 @@ async fn sse_handler(
124129 Err ( e) => Err ( io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ,
125130 }
126131 } ) ) ;
127- Ok ( Sse :: new ( stream) )
132+ Ok ( Sse :: new ( stream) . keep_alive ( KeepAlive :: new ( ) . interval ( ping_interval ) ) )
128133}
129134
130135pub struct SseServerTransport {
@@ -190,6 +195,7 @@ impl Stream for SseServerTransport {
190195 mut self : std:: pin:: Pin < & mut Self > ,
191196 cx : & mut std:: task:: Context < ' _ > ,
192197 ) -> std:: task:: Poll < Option < Self :: Item > > {
198+ use futures:: StreamExt ;
193199 self . stream . poll_next_unpin ( cx)
194200 }
195201}
@@ -200,6 +206,7 @@ pub struct SseServerConfig {
200206 pub sse_path : String ,
201207 pub post_path : String ,
202208 pub ct : CancellationToken ,
209+ pub sse_keep_alive : Option < Duration > ,
203210}
204211
205212#[ derive( Debug ) ]
@@ -215,6 +222,7 @@ impl SseServer {
215222 sse_path : "/sse" . to_string ( ) ,
216223 post_path : "/message" . to_string ( ) ,
217224 ct : CancellationToken :: new ( ) ,
225+ sse_keep_alive : None ,
218226 } )
219227 . await
220228 }
@@ -240,7 +248,10 @@ impl SseServer {
240248 /// Warning: This function creates a new SseServer instance with the provided configuration.
241249 /// `App.post_path` may be incorrect if using `Router` as an embedded router.
242250 pub fn new ( config : SseServerConfig ) -> ( SseServer , Router ) {
243- let ( app, transport_rx) = App :: new ( config. post_path . clone ( ) ) ;
251+ let ( app, transport_rx) = App :: new (
252+ config. post_path . clone ( ) ,
253+ config. sse_keep_alive . unwrap_or ( DEFAULT_AUTO_PING_INTERVAL ) ,
254+ ) ;
244255 let router = Router :: new ( )
245256 . route ( & config. sse_path , get ( sse_handler) )
246257 . route ( & config. post_path , post ( post_event_handler) )
0 commit comments