Skip to content
This repository was archived by the owner on Oct 18, 2023. It is now read-only.

Commit b68a3c9

Browse files
committed
fix tests
1 parent 26f781d commit b68a3c9

File tree

15 files changed

+142
-623
lines changed

15 files changed

+142
-623
lines changed

libsqlx-server/src/allocation/mod.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,11 @@ use tokio::time::timeout;
2424
use crate::hrana;
2525
use crate::hrana::http::handle_pipeline;
2626
use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody};
27-
use crate::linc::bus::{Bus, Dispatch};
27+
use crate::linc::bus::{Dispatch};
2828
use crate::linc::proto::{
2929
BuilderStep, Enveloppe, Frames, Message, ProxyResponse, StepError, Value,
3030
};
3131
use crate::linc::{Inbound, NodeId, Outbound};
32-
use crate::manager::Manager;
3332
use crate::meta::DatabaseId;
3433

3534
use self::config::{AllocConfig, DbConfig};
@@ -505,7 +504,7 @@ impl Database {
505504
next_req_id: 0,
506505
primary_id: *primary_id,
507506
database_id: DatabaseId::from_name(&alloc.db_name),
508-
dispatcher: alloc.bus.clone(),
507+
dispatcher: alloc.dispatcher.clone(),
509508
}),
510509
}
511510
}
@@ -687,7 +686,7 @@ pub struct Allocation {
687686

688687
pub hrana_server: Arc<hrana::http::Server>,
689688
/// handle to the message bus
690-
pub bus: Arc<Bus<Arc<Manager>>>,
689+
pub dispatcher: Arc<dyn Dispatch>,
691690
pub db_name: String,
692691
}
693692

@@ -770,7 +769,7 @@ impl Allocation {
770769
next_frame_no,
771770
req_no,
772771
seq_no: 0,
773-
dipatcher: self.bus.clone() as _,
772+
dipatcher: self.dispatcher.clone() as _,
774773
notifier: frame_notifier.clone(),
775774
buffer: Vec::new(),
776775
};
@@ -818,7 +817,7 @@ impl Allocation {
818817
Message::ProxyResponse(ref r) => {
819818
if let Some(conn) = self
820819
.connections
821-
.get(&self.bus.node_id())
820+
.get(&self.dispatcher.node_id())
822821
.and_then(|m| m.get(&r.connection_id).cloned())
823822
{
824823
conn.inbound.send(msg).await.unwrap();
@@ -837,7 +836,7 @@ impl Allocation {
837836
req_id: u32,
838837
program: Program,
839838
) {
840-
let dispatcher = self.bus.clone();
839+
let dispatcher = self.dispatcher.clone();
841840
let database_id = DatabaseId::from_name(&self.db_name);
842841
let exec = |conn: ConnectionHandle| async move {
843842
let _ = conn
@@ -878,7 +877,7 @@ impl Allocation {
878877
let conn = block_in_place(|| self.database.connect(conn_id, self));
879878
let (exec_sender, exec_receiver) = mpsc::channel(1);
880879
let (inbound_sender, inbound_receiver) = mpsc::channel(1);
881-
let id = remote.unwrap_or((self.bus.node_id(), conn_id));
880+
let id = remote.unwrap_or((self.dispatcher.node_id(), conn_id));
882881
let conn = Connection {
883882
id,
884883
conn,
@@ -903,7 +902,7 @@ impl Allocation {
903902
self.next_conn_id = self.next_conn_id.wrapping_add(1);
904903
if self
905904
.connections
906-
.get(&self.bus.node_id())
905+
.get(&self.dispatcher.node_id())
907906
.and_then(|m| m.get(&self.next_conn_id))
908907
.is_none()
909908
{

libsqlx-server/src/linc/bus.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ impl<H: Handler> Bus<H> {
5050
#[async_trait::async_trait]
5151
pub trait Dispatch: Send + Sync + 'static {
5252
async fn dispatch(&self, msg: Outbound);
53+
fn node_id(&self) -> NodeId;
5354
}
5455

5556
#[async_trait::async_trait]
@@ -62,4 +63,8 @@ impl<H: Handler> Dispatch for Bus<H> {
6263
// This message is outbound.
6364
self.send_queue.enqueue(msg).await;
6465
}
66+
67+
fn node_id(&self) -> NodeId {
68+
self.node_id
69+
}
6570
}

libsqlx-server/src/linc/connection.rs

Lines changed: 37 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -273,161 +273,60 @@ where
273273
mod test {
274274
use std::sync::Arc;
275275

276+
use futures::{future, pin_mut};
276277
use tokio::sync::Notify;
277278
use turmoil::net::{TcpListener, TcpStream};
278-
use uuid::Uuid;
279279

280280
use super::*;
281281

282282
#[test]
283283
fn invalid_handshake() {
284284
let mut sim = turmoil::Builder::new().build();
285285

286-
let host_node_id = NodeId::new_v4();
287-
sim.host("host", move || async move {
288-
let bus = Bus::new(host_node_id);
289-
let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234")
290-
.await
291-
.unwrap();
292-
let (s, _) = listener.accept().await.unwrap();
293-
let mut connection = Connection::new_acceptor(s, bus);
294-
connection.tick().await;
295-
296-
Ok(())
286+
let host_node_id = 0;
287+
let done = Arc::new(Notify::new());
288+
let done_clone = done.clone();
289+
sim.host("host", move || {
290+
let done_clone = done_clone.clone();
291+
async move {
292+
let bus = Arc::new(Bus::new(host_node_id, |_, _| async {}));
293+
let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234")
294+
.await
295+
.unwrap();
296+
let (s, _) = listener.accept().await.unwrap();
297+
let connection = Connection::new_acceptor(s, bus);
298+
let done = done_clone.notified();
299+
let run = connection.run();
300+
pin_mut!(done);
301+
pin_mut!(run);
302+
future::select(run, done).await;
303+
304+
Ok(())
305+
}
297306
});
298307

299308
sim.client("client", async move {
300309
let s = TcpStream::connect("host:1234").await.unwrap();
301-
let mut s = AsyncBincodeStream::<_, Message, Message, _>::from(s).for_async();
302-
303-
s.send(Message::Node(NodeMessage::Handshake {
304-
protocol_version: 1234,
305-
node_id: Uuid::new_v4(),
306-
}))
307-
.await
308-
.unwrap();
310+
let mut s = AsyncBincodeStream::<_, Enveloppe, Enveloppe, _>::from(s).for_async();
311+
312+
let msg = Enveloppe {
313+
database_id: None,
314+
message: Message::Handshake {
315+
protocol_version: 1234,
316+
node_id: 1,
317+
},
318+
};
319+
s.send(msg).await.unwrap();
309320
let m = s.next().await.unwrap().unwrap();
310321

311322
assert!(matches!(
312-
m,
313-
Message::Node(NodeMessage::Error(
314-
NodeError::HandshakeVersionMismatch { .. }
315-
))
323+
m.message,
324+
Message::Error(
325+
ProtoError::HandshakeVersionMismatch { .. }
326+
)
316327
));
317328

318-
Ok(())
319-
});
320-
321-
sim.run().unwrap();
322-
}
323-
324-
#[test]
325-
fn stream_closed() {
326-
let mut sim = turmoil::Builder::new().build();
327-
328-
let database_id = DatabaseId::new_v4();
329-
let host_node_id = NodeId::new_v4();
330-
let notify = Arc::new(Notify::new());
331-
sim.host("host", {
332-
let notify = notify.clone();
333-
move || {
334-
let notify = notify.clone();
335-
async move {
336-
let bus = Bus::new(host_node_id);
337-
let mut sub = bus.subscribe(database_id).unwrap();
338-
let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234")
339-
.await
340-
.unwrap();
341-
let (s, _) = listener.accept().await.unwrap();
342-
let connection = Connection::new_acceptor(s, bus);
343-
tokio::task::spawn_local(connection.run());
344-
let mut streams = Vec::new();
345-
loop {
346-
tokio::select! {
347-
Some(mut stream) = sub.next() => {
348-
let m = stream.next().await.unwrap();
349-
stream.send(m).await.unwrap();
350-
streams.push(stream);
351-
}
352-
_ = notify.notified() => {
353-
break;
354-
}
355-
}
356-
}
357-
358-
Ok(())
359-
}
360-
}
361-
});
362-
363-
sim.client("client", async move {
364-
let stream_id = StreamId::new(1);
365-
let node_id = NodeId::new_v4();
366-
let s = TcpStream::connect("host:1234").await.unwrap();
367-
let mut s = AsyncBincodeStream::<_, Message, Message, _>::from(s).for_async();
368-
369-
s.send(Message::Node(NodeMessage::Handshake {
370-
protocol_version: CURRENT_PROTO_VERSION,
371-
node_id,
372-
}))
373-
.await
374-
.unwrap();
375-
let m = s.next().await.unwrap().unwrap();
376-
assert!(matches!(m, Message::Node(NodeMessage::Handshake { .. })));
377-
378-
// send message to unexisting stream:
379-
s.send(Message::Stream {
380-
stream_id,
381-
payload: StreamMessage::Dummy,
382-
})
383-
.await
384-
.unwrap();
385-
let m = s.next().await.unwrap().unwrap();
386-
assert_eq!(
387-
m,
388-
Message::Node(NodeMessage::Error(NodeError::UnknownStream(stream_id)))
389-
);
390-
391-
// open stream then send message
392-
s.send(Message::Node(NodeMessage::OpenStream {
393-
stream_id,
394-
database_id,
395-
}))
396-
.await
397-
.unwrap();
398-
s.send(Message::Stream {
399-
stream_id,
400-
payload: StreamMessage::Dummy,
401-
})
402-
.await
403-
.unwrap();
404-
let m = s.next().await.unwrap().unwrap();
405-
assert_eq!(
406-
m,
407-
Message::Stream {
408-
stream_id,
409-
payload: StreamMessage::Dummy
410-
}
411-
);
412-
413-
s.send(Message::Node(NodeMessage::CloseStream {
414-
stream_id: StreamId::new(1),
415-
}))
416-
.await
417-
.unwrap();
418-
s.send(Message::Stream {
419-
stream_id,
420-
payload: StreamMessage::Dummy,
421-
})
422-
.await
423-
.unwrap();
424-
let m = s.next().await.unwrap().unwrap();
425-
assert_eq!(
426-
m,
427-
Message::Node(NodeMessage::Error(NodeError::UnknownStream(stream_id)))
428-
);
429-
430-
notify.notify_waiters();
329+
done.notify_waiters();
431330

432331
Ok(())
433332
});
@@ -459,7 +358,7 @@ mod test {
459358

460359
sim.client("client", async move {
461360
let stream = TcpStream::connect("host:1234").await.unwrap();
462-
let bus = Bus::new(NodeId::new_v4());
361+
let bus = Arc::new(Bus::new(1, |_, _| async {}));
463362
let mut conn = Connection::new_acceptor(stream, bus);
464363

465364
notify.notify_waiters();
@@ -473,57 +372,4 @@ mod test {
473372

474373
sim.run().unwrap();
475374
}
476-
477-
#[test]
478-
fn zero_stream_id() {
479-
let mut sim = turmoil::Builder::new().build();
480-
481-
let notify = Arc::new(Notify::new());
482-
sim.host("host", {
483-
let notify = notify.clone();
484-
move || {
485-
let notify = notify.clone();
486-
async move {
487-
let listener = TcpListener::bind("0.0.0.0:1234").await.unwrap();
488-
let (stream, _) = listener.accept().await.unwrap();
489-
let (connection_messages_sender, connection_messages) = mpsc::channel(1);
490-
let conn = Connection {
491-
peer: Some(NodeId::new_v4()),
492-
state: ConnectionState::Connected,
493-
conn: AsyncBincodeStream::from(stream).for_async(),
494-
streams: HashMap::new(),
495-
connection_messages,
496-
connection_messages_sender,
497-
is_initiator: false,
498-
bus: Bus::new(NodeId::new_v4()),
499-
stream_id_allocator: StreamIdAllocator::new(false),
500-
registration: None,
501-
};
502-
503-
conn.run().await;
504-
505-
Ok(())
506-
}
507-
}
508-
});
509-
510-
sim.client("client", async move {
511-
let stream = TcpStream::connect("host:1234").await.unwrap();
512-
let mut stream = AsyncBincodeStream::<_, Message, Message, _>::from(stream).for_async();
513-
514-
stream
515-
.send(Message::Stream {
516-
stream_id: StreamId::new_unchecked(0),
517-
payload: StreamMessage::Dummy,
518-
})
519-
.await
520-
.unwrap();
521-
522-
assert!(stream.next().await.is_none());
523-
524-
Ok(())
525-
});
526-
527-
sim.run().unwrap();
528-
}
529375
}

0 commit comments

Comments
 (0)