@@ -3,6 +3,7 @@ use std::{fmt::Debug, marker::PhantomData, path::PathBuf};
33use anyhow:: { anyhow, Context } ;
44use pallas_network:: facades:: DmqClient ;
55use slog:: { debug, error, Logger } ;
6+ use tokio:: sync:: { Mutex , MutexGuard } ;
67
78use mithril_common:: {
89 crypto_helper:: { OpCert , TryFromBytes } ,
@@ -19,6 +20,7 @@ use crate::DmqConsumer;
1920pub struct DmqConsumerPallas < M : TryFromBytes + Debug > {
2021 socket : PathBuf ,
2122 network : CardanoNetwork ,
23+ client : Mutex < Option < DmqClient > > ,
2224 logger : Logger ,
2325 phantom : PhantomData < M > ,
2426}
@@ -29,26 +31,72 @@ impl<M: TryFromBytes + Debug> DmqConsumerPallas<M> {
2931 Self {
3032 socket,
3133 network,
34+ client : Mutex :: new ( None ) ,
3235 logger : logger. new_with_component_name :: < Self > ( ) ,
3336 phantom : PhantomData ,
3437 }
3538 }
3639
3740 /// Creates and returns a new `DmqClient` connected to the specified socket.
3841 async fn new_client ( & self ) -> StdResult < DmqClient > {
39- let magic = self . network . code ( ) ;
40- DmqClient :: connect ( & self . socket , magic)
42+ debug ! (
43+ self . logger,
44+ "Create new DMQ client" ;
45+ "socket" => ?self . socket,
46+ "network" => ?self . network
47+ ) ;
48+ DmqClient :: connect ( & self . socket , self . network . code ( ) )
4149 . await
4250 . map_err ( |err| anyhow ! ( err) )
4351 . with_context ( || "PallasChainReader failed to create a new client" )
4452 }
45- }
4653
47- #[ async_trait:: async_trait]
48- impl < M : TryFromBytes + Debug + Sync + Send > DmqConsumer < M > for DmqConsumerPallas < M > {
49- async fn consume_messages ( & self ) -> StdResult < Vec < ( M , PartyId ) > > {
54+ /// Gets the cached `DmqClient`, creating a new one if it does not exist.
55+ async fn get_client ( & self ) -> StdResult < MutexGuard < Option < DmqClient > > > {
56+ {
57+ // Run this in a separate block to avoid dead lock on the Mutex
58+ let client_lock = self . client . lock ( ) . await ;
59+ if client_lock. as_ref ( ) . is_some ( ) {
60+ return Ok ( client_lock) ;
61+ }
62+ }
63+
64+ let mut client_lock = self . client . lock ( ) . await ;
65+ * client_lock = Some ( self . new_client ( ) . await ?) ;
66+
67+ Ok ( client_lock)
68+ }
69+
70+ /// Drops the current `DmqClient`, if it exists.
71+ async fn drop_client ( & self ) -> StdResult < ( ) > {
72+ debug ! (
73+ self . logger,
74+ "Drop exsiting DMQ client" ;
75+ "socket" => ?self . socket,
76+ "network" => ?self . network
77+ ) ;
78+ let mut client_lock = self . client . lock ( ) . await ;
79+ if let Some ( client) = client_lock. take ( ) {
80+ client. abort ( ) . await ;
81+ }
82+
83+ Ok ( ( ) )
84+ }
85+
86+ #[ cfg( test) ]
87+ /// Check if the client already exists (test only).
88+ async fn has_client ( & self ) -> bool {
89+ let client_lock = self . client . lock ( ) . await ;
90+
91+ client_lock. as_ref ( ) . is_some ( )
92+ }
93+
94+ async fn consume_messages_internal ( & self ) -> StdResult < Vec < ( M , PartyId ) > > {
5095 debug ! ( self . logger, "Waiting for messages from DMQ..." ) ;
51- let mut client = self . new_client ( ) . await ?; // TODO: add client cache
96+ let mut client_guard = self . get_client ( ) . await ?;
97+ let client = client_guard
98+ . as_mut ( )
99+ . ok_or ( anyhow ! ( "DMQ client does not exist" ) ) ?;
52100 client
53101 . msg_notification ( )
54102 . send_request_messages_blocking ( )
@@ -78,13 +126,28 @@ impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumer<M> for DmqConsumerPallas
78126 }
79127}
80128
81- #[ cfg( test) ]
129+ #[ async_trait:: async_trait]
130+ impl < M : TryFromBytes + Debug + Sync + Send > DmqConsumer < M > for DmqConsumerPallas < M > {
131+ async fn consume_messages ( & self ) -> StdResult < Vec < ( M , PartyId ) > > {
132+ let messages = self . consume_messages_internal ( ) . await ;
133+ if messages. is_err ( ) {
134+ self . drop_client ( ) . await ?;
135+ }
136+
137+ messages
138+ }
139+ }
140+
141+ #[ cfg( all( test, unix) ) ]
82142mod tests {
83143
84144 use std:: { fs, future, time:: Duration , vec} ;
85145
86146 use mithril_common:: { crypto_helper:: TryToBytes , current_function, test_utils:: TempDir } ;
87- use pallas_network:: miniprotocols:: { localmsgnotification, localmsgsubmission:: DmqMsg } ;
147+ use pallas_network:: {
148+ facades:: DmqServer ,
149+ miniprotocols:: { localmsgnotification, localmsgsubmission:: DmqMsg } ,
150+ } ;
88151 use tokio:: { net:: UnixListener , task:: JoinHandle , time:: sleep} ;
89152
90153 use crate :: { test:: payload:: DmqMessageTestPayload , test_tools:: TestLogger } ;
@@ -135,7 +198,10 @@ mod tests {
135198 ]
136199 }
137200
138- fn setup_dmq_server ( socket_path : PathBuf , reply_messages : Vec < DmqMsg > ) -> JoinHandle < ( ) > {
201+ fn setup_dmq_server (
202+ socket_path : PathBuf ,
203+ reply_messages : Vec < DmqMsg > ,
204+ ) -> JoinHandle < DmqServer > {
139205 tokio:: spawn ( {
140206 async move {
141207 // server setup
@@ -169,6 +235,8 @@ mod tests {
169235 // server waits if no message available
170236 future:: pending ( ) . await
171237 }
238+
239+ server
172240 }
173241 } )
174242 }
@@ -188,7 +256,8 @@ mod tests {
188256 consumer. consume_messages ( ) . await . unwrap ( )
189257 } ) ;
190258
191- let ( _, messages) = tokio:: join!( server, client) ;
259+ let ( _, client_res) = tokio:: join!( server, client) ;
260+ let messages = client_res. unwrap ( ) ;
192261
193262 assert_eq ! (
194263 vec![
@@ -201,7 +270,7 @@ mod tests {
201270 "pool17sln0evyk5tfj6zh2qrlk9vttgy6264sfe2fkec5mheasnlx3yd" . to_string( )
202271 ) ,
203272 ] ,
204- messages. unwrap ( )
273+ messages
205274 ) ;
206275 }
207276
@@ -228,4 +297,44 @@ mod tests {
228297
229298 result. expect_err ( "Should have timed out" ) ;
230299 }
300+
301+ #[ tokio:: test]
302+ async fn pallas_dmq_consumer_client_is_dropped_when_returning_error ( ) {
303+ let socket_path = create_temp_dir ( current_function ! ( ) ) . join ( "node.socket" ) ;
304+ let reply_messages = fake_msgs ( ) ;
305+ let server = setup_dmq_server ( socket_path. clone ( ) , reply_messages) ;
306+ let client = tokio:: spawn ( async move {
307+ let consumer = DmqConsumerPallas :: < DmqMessageTestPayload > :: new (
308+ socket_path,
309+ CardanoNetwork :: TestNet ( 0 ) ,
310+ TestLogger :: stdout ( ) ,
311+ ) ;
312+
313+ consumer. consume_messages ( ) . await . unwrap ( ) ;
314+
315+ consumer
316+ } ) ;
317+
318+ let ( server_res, client_res) = tokio:: join!( server, client) ;
319+ let consumer = client_res. unwrap ( ) ;
320+ let server = server_res. unwrap ( ) ;
321+ server. abort ( ) . await ;
322+
323+ let client = tokio:: spawn ( async move {
324+ assert ! ( consumer. has_client( ) . await , "Client should exist" ) ;
325+
326+ consumer
327+ . consume_messages ( )
328+ . await
329+ . expect_err ( "Consuming messages should fail" ) ;
330+
331+ assert ! (
332+ !consumer. has_client( ) . await ,
333+ "Client should have been dropped after error"
334+ ) ;
335+
336+ consumer
337+ } ) ;
338+ client. await . unwrap ( ) ;
339+ }
231340}
0 commit comments