1717
1818
1919import inspect
20+ from collections import defaultdict
2021
2122import pytest
2223
5051ROUTER1_ADDRESS = ResolvedAddress (("1.2.3.1" , 9000 ), host_name = "host" )
5152ROUTER2_ADDRESS = ResolvedAddress (("1.2.3.1" , 9001 ), host_name = "host" )
5253ROUTER3_ADDRESS = ResolvedAddress (("1.2.3.1" , 9002 ), host_name = "host" )
53- READER_ADDRESS = ResolvedAddress (("1.2.3.1" , 9010 ), host_name = "host" )
54- WRITER_ADDRESS = ResolvedAddress (("1.2.3.1" , 9020 ), host_name = "host" )
54+ READER1_ADDRESS = ResolvedAddress (("1.2.3.1" , 9010 ), host_name = "host" )
55+ READER2_ADDRESS = ResolvedAddress (("1.2.3.1" , 9011 ), host_name = "host" )
56+ READER3_ADDRESS = ResolvedAddress (("1.2.3.1" , 9012 ), host_name = "host" )
57+ WRITER1_ADDRESS = ResolvedAddress (("1.2.3.1" , 9020 ), host_name = "host" )
5558
5659
5760@pytest .fixture
58- def routing_failure_opener (async_fake_connection_generator , mocker ):
59- def make_opener (failures = None ):
61+ def custom_routing_opener (async_fake_connection_generator , mocker ):
62+ def make_opener (failures = None , get_readers = None ):
6063 def routing_side_effect (* args , ** kwargs ):
6164 nonlocal failures
6265 res = next (failures , None )
6366 if res is None :
67+ if get_readers is not None :
68+ readers = get_readers (kwargs .get ("database" ))
69+ else :
70+ readers = [str (READER1_ADDRESS )]
6471 return [{
6572 "ttl" : 1000 ,
6673 "servers" : [
6774 {"addresses" : [str (ROUTER1_ADDRESS ),
6875 str (ROUTER2_ADDRESS ),
6976 str (ROUTER3_ADDRESS )],
7077 "role" : "ROUTE" },
71- {"addresses" : [ str ( READER_ADDRESS )] , "role" : "READ" },
72- {"addresses" : [str (WRITER_ADDRESS )], "role" : "WRITE" },
78+ {"addresses" : readers , "role" : "READ" },
79+ {"addresses" : [str (WRITER1_ADDRESS )], "role" : "WRITE" },
7380 ],
7481 }]
7582 raise res
@@ -96,8 +103,8 @@ async def open_(addr, auth, timeout):
96103
97104
98105@pytest .fixture
99- def opener (routing_failure_opener ):
100- return routing_failure_opener ()
106+ def opener (custom_routing_opener ):
107+ return custom_routing_opener ()
101108
102109
103110def _pool_config ():
@@ -177,9 +184,9 @@ async def test_chooses_right_connection_type(opener, type_):
177184 )
178185 await pool .release (cx1 )
179186 if type_ == "r" :
180- assert cx1 .unresolved_address == READER_ADDRESS
187+ assert cx1 .unresolved_address == READER1_ADDRESS
181188 else :
182- assert cx1 .unresolved_address == WRITER_ADDRESS
189+ assert cx1 .unresolved_address == WRITER1_ADDRESS
183190
184191
185192@mark_async_test
@@ -298,9 +305,9 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection(
298305 opener , liveness_timeout
299306):
300307 pool = _simple_pool (opener )
301- cx1 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
308+ cx1 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
302309 liveness_timeout )
303- assert cx1 .unresolved_address == READER_ADDRESS
310+ assert cx1 .unresolved_address == READER1_ADDRESS
304311 cx1 .reset .assert_not_called ()
305312
306313
@@ -311,11 +318,11 @@ async def test_acquire_performs_liveness_check_on_existing_connection(
311318):
312319 pool = _simple_pool (opener )
313320 # populate the pool with a connection
314- cx1 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
321+ cx1 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
315322 liveness_timeout )
316323
317324 # make sure we assume the right state
318- assert cx1 .unresolved_address == READER_ADDRESS
325+ assert cx1 .unresolved_address == READER1_ADDRESS
319326 cx1 .is_idle_for .assert_not_called ()
320327 cx1 .reset .assert_not_called ()
321328
@@ -326,7 +333,7 @@ async def test_acquire_performs_liveness_check_on_existing_connection(
326333 cx1 .reset .assert_not_called ()
327334
328335 # then acquire it again and assert the liveness check was performed
329- cx2 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
336+ cx2 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
330337 liveness_timeout )
331338 assert cx1 is cx2
332339 cx1 .is_idle_for .assert_called_once_with (liveness_timeout )
@@ -345,11 +352,11 @@ def liveness_side_effect(*args, **kwargs):
345352 liveness_timeout = 1
346353 pool = _simple_pool (opener )
347354 # populate the pool with a connection
348- cx1 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
355+ cx1 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
349356 liveness_timeout )
350357
351358 # make sure we assume the right state
352- assert cx1 .unresolved_address == READER_ADDRESS
359+ assert cx1 .unresolved_address == READER1_ADDRESS
353360 cx1 .is_idle_for .assert_not_called ()
354361 cx1 .reset .assert_not_called ()
355362
@@ -362,7 +369,7 @@ def liveness_side_effect(*args, **kwargs):
362369 cx1 .reset .assert_not_called ()
363370
364371 # then acquire it again and assert the liveness check was performed
365- cx2 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
372+ cx2 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
366373 liveness_timeout )
367374 assert cx1 is not cx2
368375 assert cx1 .unresolved_address == cx2 .unresolved_address
@@ -384,14 +391,14 @@ def liveness_side_effect(*args, **kwargs):
384391 liveness_timeout = 1
385392 pool = _simple_pool (opener )
386393 # populate the pool with a connection
387- cx1 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
394+ cx1 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
388395 liveness_timeout )
389- cx2 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
396+ cx2 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
390397 liveness_timeout )
391398
392399 # make sure we assume the right state
393- assert cx1 .unresolved_address == READER_ADDRESS
394- assert cx2 .unresolved_address == READER_ADDRESS
400+ assert cx1 .unresolved_address == READER1_ADDRESS
401+ assert cx2 .unresolved_address == READER1_ADDRESS
395402 assert cx1 is not cx2
396403 cx1 .is_idle_for .assert_not_called ()
397404 cx2 .is_idle_for .assert_not_called ()
@@ -409,7 +416,7 @@ def liveness_side_effect(*args, **kwargs):
409416 cx2 .reset .assert_not_called ()
410417
411418 # then acquire it again and assert the liveness check was performed
412- cx3 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
419+ cx3 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
413420 liveness_timeout )
414421 assert cx3 is cx2
415422 cx1 .is_idle_for .assert_called_once_with (liveness_timeout )
@@ -426,7 +433,7 @@ def mock_connection_breaks_on_close(cx):
426433 async def close_side_effect ():
427434 cx .closed .return_value = True
428435 cx .defunct .return_value = True
429- await pool .deactivate (READER_ADDRESS )
436+ await pool .deactivate (READER1_ADDRESS )
430437
431438 cx .attach_mock (mocker .AsyncMock (side_effect = close_side_effect ),
432439 "close" )
@@ -470,9 +477,9 @@ async def test__acquire_new_later_with_room(opener):
470477 pool = AsyncNeo4jPool (
471478 opener , config , WorkspaceConfig (), ROUTER1_ADDRESS
472479 )
473- assert pool .connections_reservations [READER_ADDRESS ] == 0
474- creator = pool ._acquire_new_later (READER_ADDRESS , None , Deadline (1 ))
475- assert pool .connections_reservations [READER_ADDRESS ] == 1
480+ assert pool .connections_reservations [READER1_ADDRESS ] == 0
481+ creator = pool ._acquire_new_later (READER1_ADDRESS , None , Deadline (1 ))
482+ assert pool .connections_reservations [READER1_ADDRESS ] == 1
476483 assert callable (creator )
477484 if AsyncUtil .is_async_code :
478485 assert inspect .iscoroutinefunction (creator )
@@ -487,9 +494,9 @@ async def test__acquire_new_later_without_room(opener):
487494 )
488495 _ = await pool .acquire (READ_ACCESS , 30 , "test_db" , None , None , None )
489496 # pool is full now
490- assert pool .connections_reservations [READER_ADDRESS ] == 0
491- creator = pool ._acquire_new_later (READER_ADDRESS , None , Deadline (1 ))
492- assert pool .connections_reservations [READER_ADDRESS ] == 0
497+ assert pool .connections_reservations [READER1_ADDRESS ] == 0
498+ creator = pool ._acquire_new_later (READER1_ADDRESS , None , Deadline (1 ))
499+ assert pool .connections_reservations [READER1_ADDRESS ] == 0
493500 assert creator is None
494501
495502
@@ -519,8 +526,8 @@ async def test_passes_pool_config_to_connection(mocker):
519526 "Neo.ClientError.Security.AuthorizationExpired" ),
520527))
521528@mark_async_test
522- async def test_discovery_is_retried (routing_failure_opener , error ):
523- opener = routing_failure_opener ([
529+ async def test_discovery_is_retried (custom_routing_opener , error ):
530+ opener = custom_routing_opener ([
524531 None , # first call to router for seeding the RT with more routers
525532 error , # will be retried
526533 ])
@@ -563,8 +570,8 @@ async def test_discovery_is_retried(routing_failure_opener, error):
563570 )
564571))
565572@mark_async_test
566- async def test_fast_failing_discovery (routing_failure_opener , error ):
567- opener = routing_failure_opener ([
573+ async def test_fast_failing_discovery (custom_routing_opener , error ):
574+ opener = custom_routing_opener ([
568575 None , # first call to router for seeding the RT with more routers
569576 error , # will be retried
570577 ])
@@ -648,3 +655,85 @@ async def test_connection_error_callback(
648655 cx .mark_unauthenticated .assert_not_called ()
649656 for cx in cxs_write :
650657 cx .mark_unauthenticated .assert_not_called ()
658+
659+
660+ @mark_async_test
661+ async def test_pool_closes_connections_dropped_from_rt (custom_routing_opener ):
662+ readers = {"db1" : [str (READER1_ADDRESS )]}
663+
664+ def get_readers (database ):
665+ return readers [database ]
666+
667+ opener = custom_routing_opener (get_readers = get_readers )
668+
669+ pool = AsyncNeo4jPool (
670+ opener , _pool_config (), WorkspaceConfig (), ROUTER1_ADDRESS
671+ )
672+ cx1 = await pool .acquire (READ_ACCESS , 30 , "db1" , None , None , None )
673+ assert cx1 .unresolved_address == READER1_ADDRESS
674+ await pool .release (cx1 )
675+
676+ cx1 .close .assert_not_called ()
677+ assert len (pool .connections [READER1_ADDRESS ]) == 1
678+
679+ # force RT refresh, returning a different reader
680+ del pool .routing_tables ["db1" ]
681+ readers ["db1" ] = [str (READER2_ADDRESS )]
682+
683+ cx2 = await pool .acquire (READ_ACCESS , 30 , "db1" , None , None , None )
684+ assert cx2 .unresolved_address == READER2_ADDRESS
685+
686+ cx1 .close .assert_awaited_once ()
687+ assert len (pool .connections [READER1_ADDRESS ]) == 0
688+
689+ await pool .release (cx2 )
690+ assert len (pool .connections [READER2_ADDRESS ]) == 1
691+
692+
693+ @mark_async_test
694+ async def test_pool_does_not_close_connections_dropped_from_rt_for_other_server (
695+ custom_routing_opener
696+ ):
697+ readers = {
698+ "db1" : [str (READER1_ADDRESS ), str (READER2_ADDRESS )],
699+ "db2" : [str (READER1_ADDRESS )]
700+ }
701+
702+ def get_readers (database ):
703+ return readers [database ]
704+
705+ opener = custom_routing_opener (get_readers = get_readers )
706+
707+ pool = AsyncNeo4jPool (
708+ opener , _pool_config (), WorkspaceConfig (), ROUTER1_ADDRESS
709+ )
710+ cx1 = await pool .acquire (READ_ACCESS , 30 , "db1" , None , None , None )
711+ await pool .release (cx1 )
712+ assert cx1 .unresolved_address in (READER1_ADDRESS , READER2_ADDRESS )
713+ reader1_connection_count = len (pool .connections [READER1_ADDRESS ])
714+ reader2_connection_count = len (pool .connections [READER2_ADDRESS ])
715+ assert reader1_connection_count + reader2_connection_count == 1
716+
717+ cx2 = await pool .acquire (READ_ACCESS , 30 , "db2" , None , None , None )
718+ await pool .release (cx2 )
719+ assert cx2 .unresolved_address == READER1_ADDRESS
720+ cx1 .close .assert_not_called ()
721+ cx2 .close .assert_not_called ()
722+ assert len (pool .connections [READER1_ADDRESS ]) == 1
723+ assert len (pool .connections [READER2_ADDRESS ]) == reader2_connection_count
724+
725+
726+ # force RT refresh, returning a different reader
727+ del pool .routing_tables ["db2" ]
728+ readers ["db2" ] = [str (READER3_ADDRESS )]
729+
730+ cx3 = await pool .acquire (READ_ACCESS , 30 , "db2" , None , None , None )
731+ await pool .release (cx3 )
732+ assert cx3 .unresolved_address == READER3_ADDRESS
733+
734+ cx1 .close .assert_not_called ()
735+ cx2 .close .assert_not_called ()
736+ cx3 .close .assert_not_called ()
737+ assert len (pool .connections [READER1_ADDRESS ]) == 1
738+ assert len (pool .connections [READER2_ADDRESS ]) == reader2_connection_count
739+ assert len (pool .connections [READER3_ADDRESS ]) == 1
0 commit comments