@@ -273,161 +273,60 @@ where
273
273
mod test {
274
274
use std:: sync:: Arc ;
275
275
276
+ use futures:: { future, pin_mut} ;
276
277
use tokio:: sync:: Notify ;
277
278
use turmoil:: net:: { TcpListener , TcpStream } ;
278
- use uuid:: Uuid ;
279
279
280
280
use super :: * ;
281
281
282
282
#[ test]
283
283
fn invalid_handshake ( ) {
284
284
let mut sim = turmoil:: Builder :: new ( ) . build ( ) ;
285
285
286
- let host_node_id = NodeId :: new_v4 ( ) ;
287
- sim. host ( "host" , move || async move {
288
- let bus = Bus :: new ( host_node_id) ;
289
- let listener = turmoil:: net:: TcpListener :: bind ( "0.0.0.0:1234" )
290
- . await
291
- . unwrap ( ) ;
292
- let ( s, _) = listener. accept ( ) . await . unwrap ( ) ;
293
- let mut connection = Connection :: new_acceptor ( s, bus) ;
294
- connection. tick ( ) . await ;
295
-
296
- Ok ( ( ) )
286
+ let host_node_id = 0 ;
287
+ let done = Arc :: new ( Notify :: new ( ) ) ;
288
+ let done_clone = done. clone ( ) ;
289
+ sim. host ( "host" , move || {
290
+ let done_clone = done_clone. clone ( ) ;
291
+ async move {
292
+ let bus = Arc :: new ( Bus :: new ( host_node_id, |_, _| async { } ) ) ;
293
+ let listener = turmoil:: net:: TcpListener :: bind ( "0.0.0.0:1234" )
294
+ . await
295
+ . unwrap ( ) ;
296
+ let ( s, _) = listener. accept ( ) . await . unwrap ( ) ;
297
+ let connection = Connection :: new_acceptor ( s, bus) ;
298
+ let done = done_clone. notified ( ) ;
299
+ let run = connection. run ( ) ;
300
+ pin_mut ! ( done) ;
301
+ pin_mut ! ( run) ;
302
+ future:: select ( run, done) . await ;
303
+
304
+ Ok ( ( ) )
305
+ }
297
306
} ) ;
298
307
299
308
sim. client ( "client" , async move {
300
309
let s = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ;
301
- let mut s = AsyncBincodeStream :: < _ , Message , Message , _ > :: from ( s) . for_async ( ) ;
302
-
303
- s. send ( Message :: Node ( NodeMessage :: Handshake {
304
- protocol_version : 1234 ,
305
- node_id : Uuid :: new_v4 ( ) ,
306
- } ) )
307
- . await
308
- . unwrap ( ) ;
310
+ let mut s = AsyncBincodeStream :: < _ , Enveloppe , Enveloppe , _ > :: from ( s) . for_async ( ) ;
311
+
312
+ let msg = Enveloppe {
313
+ database_id : None ,
314
+ message : Message :: Handshake {
315
+ protocol_version : 1234 ,
316
+ node_id : 1 ,
317
+ } ,
318
+ } ;
319
+ s. send ( msg) . await . unwrap ( ) ;
309
320
let m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ;
310
321
311
322
assert ! ( matches!(
312
- m,
313
- Message :: Node ( NodeMessage :: Error (
314
- NodeError :: HandshakeVersionMismatch { .. }
315
- ) )
323
+ m. message ,
324
+ Message :: Error (
325
+ ProtoError :: HandshakeVersionMismatch { .. }
326
+ )
316
327
) ) ;
317
328
318
- Ok ( ( ) )
319
- } ) ;
320
-
321
- sim. run ( ) . unwrap ( ) ;
322
- }
323
-
324
- #[ test]
325
- fn stream_closed ( ) {
326
- let mut sim = turmoil:: Builder :: new ( ) . build ( ) ;
327
-
328
- let database_id = DatabaseId :: new_v4 ( ) ;
329
- let host_node_id = NodeId :: new_v4 ( ) ;
330
- let notify = Arc :: new ( Notify :: new ( ) ) ;
331
- sim. host ( "host" , {
332
- let notify = notify. clone ( ) ;
333
- move || {
334
- let notify = notify. clone ( ) ;
335
- async move {
336
- let bus = Bus :: new ( host_node_id) ;
337
- let mut sub = bus. subscribe ( database_id) . unwrap ( ) ;
338
- let listener = turmoil:: net:: TcpListener :: bind ( "0.0.0.0:1234" )
339
- . await
340
- . unwrap ( ) ;
341
- let ( s, _) = listener. accept ( ) . await . unwrap ( ) ;
342
- let connection = Connection :: new_acceptor ( s, bus) ;
343
- tokio:: task:: spawn_local ( connection. run ( ) ) ;
344
- let mut streams = Vec :: new ( ) ;
345
- loop {
346
- tokio:: select! {
347
- Some ( mut stream) = sub. next( ) => {
348
- let m = stream. next( ) . await . unwrap( ) ;
349
- stream. send( m) . await . unwrap( ) ;
350
- streams. push( stream) ;
351
- }
352
- _ = notify. notified( ) => {
353
- break ;
354
- }
355
- }
356
- }
357
-
358
- Ok ( ( ) )
359
- }
360
- }
361
- } ) ;
362
-
363
- sim. client ( "client" , async move {
364
- let stream_id = StreamId :: new ( 1 ) ;
365
- let node_id = NodeId :: new_v4 ( ) ;
366
- let s = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ;
367
- let mut s = AsyncBincodeStream :: < _ , Message , Message , _ > :: from ( s) . for_async ( ) ;
368
-
369
- s. send ( Message :: Node ( NodeMessage :: Handshake {
370
- protocol_version : CURRENT_PROTO_VERSION ,
371
- node_id,
372
- } ) )
373
- . await
374
- . unwrap ( ) ;
375
- let m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ;
376
- assert ! ( matches!( m, Message :: Node ( NodeMessage :: Handshake { .. } ) ) ) ;
377
-
378
- // send message to unexisting stream:
379
- s. send ( Message :: Stream {
380
- stream_id,
381
- payload : StreamMessage :: Dummy ,
382
- } )
383
- . await
384
- . unwrap ( ) ;
385
- let m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ;
386
- assert_eq ! (
387
- m,
388
- Message :: Node ( NodeMessage :: Error ( NodeError :: UnknownStream ( stream_id) ) )
389
- ) ;
390
-
391
- // open stream then send message
392
- s. send ( Message :: Node ( NodeMessage :: OpenStream {
393
- stream_id,
394
- database_id,
395
- } ) )
396
- . await
397
- . unwrap ( ) ;
398
- s. send ( Message :: Stream {
399
- stream_id,
400
- payload : StreamMessage :: Dummy ,
401
- } )
402
- . await
403
- . unwrap ( ) ;
404
- let m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ;
405
- assert_eq ! (
406
- m,
407
- Message :: Stream {
408
- stream_id,
409
- payload: StreamMessage :: Dummy
410
- }
411
- ) ;
412
-
413
- s. send ( Message :: Node ( NodeMessage :: CloseStream {
414
- stream_id : StreamId :: new ( 1 ) ,
415
- } ) )
416
- . await
417
- . unwrap ( ) ;
418
- s. send ( Message :: Stream {
419
- stream_id,
420
- payload : StreamMessage :: Dummy ,
421
- } )
422
- . await
423
- . unwrap ( ) ;
424
- let m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ;
425
- assert_eq ! (
426
- m,
427
- Message :: Node ( NodeMessage :: Error ( NodeError :: UnknownStream ( stream_id) ) )
428
- ) ;
429
-
430
- notify. notify_waiters ( ) ;
329
+ done. notify_waiters ( ) ;
431
330
432
331
Ok ( ( ) )
433
332
} ) ;
@@ -459,7 +358,7 @@ mod test {
459
358
460
359
sim. client ( "client" , async move {
461
360
let stream = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ;
462
- let bus = Bus :: new ( NodeId :: new_v4 ( ) ) ;
361
+ let bus = Arc :: new ( Bus :: new ( 1 , |_ , _| async { } ) ) ;
463
362
let mut conn = Connection :: new_acceptor ( stream, bus) ;
464
363
465
364
notify. notify_waiters ( ) ;
@@ -473,57 +372,4 @@ mod test {
473
372
474
373
sim. run ( ) . unwrap ( ) ;
475
374
}
476
-
477
- #[ test]
478
- fn zero_stream_id ( ) {
479
- let mut sim = turmoil:: Builder :: new ( ) . build ( ) ;
480
-
481
- let notify = Arc :: new ( Notify :: new ( ) ) ;
482
- sim. host ( "host" , {
483
- let notify = notify. clone ( ) ;
484
- move || {
485
- let notify = notify. clone ( ) ;
486
- async move {
487
- let listener = TcpListener :: bind ( "0.0.0.0:1234" ) . await . unwrap ( ) ;
488
- let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
489
- let ( connection_messages_sender, connection_messages) = mpsc:: channel ( 1 ) ;
490
- let conn = Connection {
491
- peer : Some ( NodeId :: new_v4 ( ) ) ,
492
- state : ConnectionState :: Connected ,
493
- conn : AsyncBincodeStream :: from ( stream) . for_async ( ) ,
494
- streams : HashMap :: new ( ) ,
495
- connection_messages,
496
- connection_messages_sender,
497
- is_initiator : false ,
498
- bus : Bus :: new ( NodeId :: new_v4 ( ) ) ,
499
- stream_id_allocator : StreamIdAllocator :: new ( false ) ,
500
- registration : None ,
501
- } ;
502
-
503
- conn. run ( ) . await ;
504
-
505
- Ok ( ( ) )
506
- }
507
- }
508
- } ) ;
509
-
510
- sim. client ( "client" , async move {
511
- let stream = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ;
512
- let mut stream = AsyncBincodeStream :: < _ , Message , Message , _ > :: from ( stream) . for_async ( ) ;
513
-
514
- stream
515
- . send ( Message :: Stream {
516
- stream_id : StreamId :: new_unchecked ( 0 ) ,
517
- payload : StreamMessage :: Dummy ,
518
- } )
519
- . await
520
- . unwrap ( ) ;
521
-
522
- assert ! ( stream. next( ) . await . is_none( ) ) ;
523
-
524
- Ok ( ( ) )
525
- } ) ;
526
-
527
- sim. run ( ) . unwrap ( ) ;
528
- }
529
375
}
0 commit comments