Skip to content

Commit 12ea76a

Browse files
committed
feat(dmq): cache 'DmqClient' in 'DmqConsumer'
1 parent 8eb2bb8 commit 12ea76a

File tree

2 files changed

+123
-14
lines changed

2 files changed

+123
-14
lines changed

internal/mithril-dmq-node/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ mithril-cardano-node-chain = { path = "../cardano-node/mithril-cardano-node-chai
1717
mithril-common = { path = "../../mithril-common" }
1818
pallas-network = { git = "https://github.com/txpipe/pallas.git", branch = "main" }
1919
slog = { workspace = true }
20-
tokio = { workspace = true }
20+
tokio = { workspace = true, features = ["sync"] }
2121

2222
[dev-dependencies]
2323
mithril-common = { path = "../../mithril-common", features = ["test_tools"] }
2424
mockall = { workspace = true }
2525
slog-async = { workspace = true }
26-
slog-term = { workspace = true }
26+
slog-term = { workspace = true }

internal/mithril-dmq-node/src/consumer/pallas.rs

Lines changed: 121 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::{fmt::Debug, marker::PhantomData, path::PathBuf};
33
use anyhow::{anyhow, Context};
44
use pallas_network::facades::DmqClient;
55
use slog::{debug, error, Logger};
6+
use tokio::sync::{Mutex, MutexGuard};
67

78
use mithril_common::{
89
crypto_helper::{OpCert, TryFromBytes},
@@ -19,6 +20,7 @@ use crate::DmqConsumer;
1920
pub struct DmqConsumerPallas<M: TryFromBytes + Debug> {
2021
socket: PathBuf,
2122
network: CardanoNetwork,
23+
client: Mutex<Option<DmqClient>>,
2224
logger: Logger,
2325
phantom: PhantomData<M>,
2426
}
@@ -29,26 +31,72 @@ impl<M: TryFromBytes + Debug> DmqConsumerPallas<M> {
2931
Self {
3032
socket,
3133
network,
34+
client: Mutex::new(None),
3235
logger: logger.new_with_component_name::<Self>(),
3336
phantom: PhantomData,
3437
}
3538
}
3639

3740
/// Creates and returns a new `DmqClient` connected to the specified socket.
3841
async fn new_client(&self) -> StdResult<DmqClient> {
39-
let magic = self.network.code();
40-
DmqClient::connect(&self.socket, magic)
42+
debug!(
43+
self.logger,
44+
"Create new DMQ client";
45+
"socket" => ?self.socket,
46+
"network" => ?self.network
47+
);
48+
DmqClient::connect(&self.socket, self.network.code())
4149
.await
4250
.map_err(|err| anyhow!(err))
4351
.with_context(|| "PallasChainReader failed to create a new client")
4452
}
45-
}
4653

47-
#[async_trait::async_trait]
48-
impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumer<M> for DmqConsumerPallas<M> {
49-
async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
54+
/// Gets the cached `DmqClient`, creating a new one if it does not exist.
55+
async fn get_client(&self) -> StdResult<MutexGuard<Option<DmqClient>>> {
56+
{
57+
// Run this in a separate block to avoid dead lock on the Mutex
58+
let client_lock = self.client.lock().await;
59+
if client_lock.as_ref().is_some() {
60+
return Ok(client_lock);
61+
}
62+
}
63+
64+
let mut client_lock = self.client.lock().await;
65+
*client_lock = Some(self.new_client().await?);
66+
67+
Ok(client_lock)
68+
}
69+
70+
/// Drops the current `DmqClient`, if it exists.
71+
async fn drop_client(&self) -> StdResult<()> {
72+
debug!(
73+
self.logger,
74+
"Drop exsiting DMQ client";
75+
"socket" => ?self.socket,
76+
"network" => ?self.network
77+
);
78+
let mut client_lock = self.client.lock().await;
79+
if let Some(client) = client_lock.take() {
80+
client.abort().await;
81+
}
82+
83+
Ok(())
84+
}
85+
86+
#[cfg(test)]
87+
/// Check if the client already exists (test only).
88+
async fn has_client(&self) -> bool {
89+
let client_lock = self.client.lock().await;
90+
91+
client_lock.as_ref().is_some()
92+
}
93+
94+
async fn consume_messages_internal(&self) -> StdResult<Vec<(M, PartyId)>> {
5095
debug!(self.logger, "Waiting for messages from DMQ...");
51-
let mut client = self.new_client().await?; // TODO: add client cache
96+
let mut client_guard = self.get_client().await?;
97+
let client = client_guard
98+
.as_mut()
99+
.ok_or(anyhow!("DMQ client does not exist"))?;
52100
client
53101
.msg_notification()
54102
.send_request_messages_blocking()
@@ -78,13 +126,28 @@ impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumer<M> for DmqConsumerPallas
78126
}
79127
}
80128

81-
#[cfg(test)]
129+
#[async_trait::async_trait]
130+
impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumer<M> for DmqConsumerPallas<M> {
131+
async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
132+
let messages = self.consume_messages_internal().await;
133+
if messages.is_err() {
134+
self.drop_client().await?;
135+
}
136+
137+
messages
138+
}
139+
}
140+
141+
#[cfg(all(test, unix))]
82142
mod tests {
83143

84144
use std::{fs, future, time::Duration, vec};
85145

86146
use mithril_common::{crypto_helper::TryToBytes, current_function, test_utils::TempDir};
87-
use pallas_network::miniprotocols::{localmsgnotification, localmsgsubmission::DmqMsg};
147+
use pallas_network::{
148+
facades::DmqServer,
149+
miniprotocols::{localmsgnotification, localmsgsubmission::DmqMsg},
150+
};
88151
use tokio::{net::UnixListener, task::JoinHandle, time::sleep};
89152

90153
use crate::{test::payload::DmqMessageTestPayload, test_tools::TestLogger};
@@ -135,7 +198,10 @@ mod tests {
135198
]
136199
}
137200

138-
fn setup_dmq_server(socket_path: PathBuf, reply_messages: Vec<DmqMsg>) -> JoinHandle<()> {
201+
fn setup_dmq_server(
202+
socket_path: PathBuf,
203+
reply_messages: Vec<DmqMsg>,
204+
) -> JoinHandle<DmqServer> {
139205
tokio::spawn({
140206
async move {
141207
// server setup
@@ -169,6 +235,8 @@ mod tests {
169235
// server waits if no message available
170236
future::pending().await
171237
}
238+
239+
server
172240
}
173241
})
174242
}
@@ -188,7 +256,8 @@ mod tests {
188256
consumer.consume_messages().await.unwrap()
189257
});
190258

191-
let (_, messages) = tokio::join!(server, client);
259+
let (_, client_res) = tokio::join!(server, client);
260+
let messages = client_res.unwrap();
192261

193262
assert_eq!(
194263
vec![
@@ -201,7 +270,7 @@ mod tests {
201270
"pool17sln0evyk5tfj6zh2qrlk9vttgy6264sfe2fkec5mheasnlx3yd".to_string()
202271
),
203272
],
204-
messages.unwrap()
273+
messages
205274
);
206275
}
207276

@@ -228,4 +297,44 @@ mod tests {
228297

229298
result.expect_err("Should have timed out");
230299
}
300+
301+
#[tokio::test]
302+
async fn pallas_dmq_consumer_client_is_dropped_when_returning_error() {
303+
let socket_path = create_temp_dir(current_function!()).join("node.socket");
304+
let reply_messages = fake_msgs();
305+
let server = setup_dmq_server(socket_path.clone(), reply_messages);
306+
let client = tokio::spawn(async move {
307+
let consumer = DmqConsumerPallas::<DmqMessageTestPayload>::new(
308+
socket_path,
309+
CardanoNetwork::TestNet(0),
310+
TestLogger::stdout(),
311+
);
312+
313+
consumer.consume_messages().await.unwrap();
314+
315+
consumer
316+
});
317+
318+
let (server_res, client_res) = tokio::join!(server, client);
319+
let consumer = client_res.unwrap();
320+
let server = server_res.unwrap();
321+
server.abort().await;
322+
323+
let client = tokio::spawn(async move {
324+
assert!(consumer.has_client().await, "Client should exist");
325+
326+
consumer
327+
.consume_messages()
328+
.await
329+
.expect_err("Consuming messages should fail");
330+
331+
assert!(
332+
!consumer.has_client().await,
333+
"Client should have been dropped after error"
334+
);
335+
336+
consumer
337+
});
338+
client.await.unwrap();
339+
}
231340
}

0 commit comments

Comments
 (0)