diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py
index e854fe987..619a662f1 100644
--- a/src/neo4j/_async/io/_pool.py
+++ b/src/neo4j/_async/io/_pool.py
@@ -813,8 +813,13 @@ async def update_routing_table(
         raise ServiceUnavailable("Unable to retrieve routing information")
 
     async def update_connection_pool(self, *, database):
-        routing_table = await self.get_or_create_routing_table(database)
-        servers = routing_table.servers()
+        async with self.refresh_lock:
+            routing_tables = [await self.get_or_create_routing_table(database)]
+            for db in self.routing_tables.keys():
+                if db == database:
+                    continue
+                routing_tables.append(self.routing_tables[db])
+            servers = set.union(*(rt.servers() for rt in routing_tables))
         for address in list(self.connections):
             if address._unresolved not in servers:
                 await super(AsyncNeo4jPool, self).deactivate(address)
@@ -960,6 +965,7 @@ async def deactivate(self, address):
     async def on_write_failure(self, address):
         """ Remove a writer address from the routing table, if present.
         """
+        # FIXME: only need to remove the writer for a specific database
         log.debug("[#0000] _: removing writer %r", address)
         async with self.refresh_lock:
             for database in self.routing_tables.keys():
diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py
index b8adff017..73f77944a 100644
--- a/src/neo4j/_sync/io/_pool.py
+++ b/src/neo4j/_sync/io/_pool.py
@@ -810,8 +810,13 @@ def update_routing_table(
         raise ServiceUnavailable("Unable to retrieve routing information")
 
     def update_connection_pool(self, *, database):
-        routing_table = self.get_or_create_routing_table(database)
-        servers = routing_table.servers()
+        with self.refresh_lock:
+            routing_tables = [self.get_or_create_routing_table(database)]
+            for db in self.routing_tables.keys():
+                if db == database:
+                    continue
+                routing_tables.append(self.routing_tables[db])
+            servers = set.union(*(rt.servers() for rt in routing_tables))
         for address in list(self.connections):
             if address._unresolved not in servers:
                 super(Neo4jPool, self).deactivate(address)
@@ -957,6 +962,7 @@ def deactivate(self, address):
     def on_write_failure(self, address):
         """ Remove a writer address from the routing table, if present.
         """
+        # FIXME: only need to remove the writer for a specific database
         log.debug("[#0000] _: removing writer %r", address)
         with self.refresh_lock:
             for database in self.routing_tables.keys():
diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py
index ca27d92f3..a8249eed6 100644
--- a/tests/unit/async_/io/test_neo4j_pool.py
+++ b/tests/unit/async_/io/test_neo4j_pool.py
@@ -17,6 +17,7 @@
 
 
 import inspect
+from collections import defaultdict
 
 import pytest
 
@@ -50,17 +51,23 @@
 ROUTER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host")
 ROUTER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
 ROUTER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host")
-READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host")
-WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host")
+READER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host")
+READER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9011), host_name="host")
+READER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9012), host_name="host")
+WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host")
 
 
 @pytest.fixture
-def routing_failure_opener(async_fake_connection_generator, mocker):
-    def make_opener(failures=None):
+def custom_routing_opener(async_fake_connection_generator, mocker):
+    def make_opener(failures=None, get_readers=None):
         def routing_side_effect(*args, **kwargs):
             nonlocal failures
             res = next(failures, None)
             if res is None:
+                if get_readers is not None:
+                    readers = get_readers(kwargs.get("database"))
+                else:
+                    readers = [str(READER1_ADDRESS)]
                 return [{
                     "ttl": 1000,
                     "servers": [
@@ -68,8 +75,8 @@ def routing_side_effect(*args, **kwargs):
                             str(ROUTER2_ADDRESS),
                             str(ROUTER3_ADDRESS)],
                          "role": "ROUTE"},
-                        {"addresses": [str(READER_ADDRESS)], "role": "READ"},
-                        {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"},
+                        {"addresses": readers, "role": "READ"},
+                        {"addresses": [str(WRITER1_ADDRESS)], "role": "WRITE"},
                     ],
                 }]
             raise res
@@ -96,8 +103,8 @@ async def open_(addr, auth, timeout):
 
 
 @pytest.fixture
-def opener(routing_failure_opener):
-    return routing_failure_opener()
+def opener(custom_routing_opener):
+    return custom_routing_opener()
 
 
 def _pool_config():
@@ -177,9 +184,9 @@ async def test_chooses_right_connection_type(opener, type_):
     )
     await pool.release(cx1)
     if type_ == "r":
-        assert cx1.unresolved_address == READER_ADDRESS
+        assert cx1.unresolved_address == READER1_ADDRESS
     else:
-        assert cx1.unresolved_address == WRITER_ADDRESS
+        assert cx1.unresolved_address == WRITER1_ADDRESS
 
 
 @mark_async_test
@@ -298,9 +305,9 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection(
     opener, liveness_timeout
 ):
     pool = _simple_pool(opener)
-    cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
                               liveness_timeout)
-    assert cx1.unresolved_address == READER_ADDRESS
+    assert cx1.unresolved_address == READER1_ADDRESS
     cx1.reset.assert_not_called()
 
 
@@ -311,11 +318,11 @@ async def test_acquire_performs_liveness_check_on_existing_connection(
 ):
     pool = _simple_pool(opener)
     # populate the pool with a connection
-    cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
                               liveness_timeout)
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER_ADDRESS
+    assert cx1.unresolved_address == READER1_ADDRESS
     cx1.is_idle_for.assert_not_called()
     cx1.reset.assert_not_called()
 
@@ -326,7 +333,7 @@ async def test_acquire_performs_liveness_check_on_existing_connection(
     cx1.reset.assert_not_called()
 
     # then acquire it again and assert the liveness check was performed
-    cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx2 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
                               liveness_timeout)
     assert cx1 is cx2
     cx1.is_idle_for.assert_called_once_with(liveness_timeout)
@@ -345,11 +352,11 @@ def liveness_side_effect(*args, **kwargs):
     liveness_timeout = 1
     pool = _simple_pool(opener)
     # populate the pool with a connection
-    cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
                               liveness_timeout)
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER_ADDRESS
+    assert cx1.unresolved_address == READER1_ADDRESS
     cx1.is_idle_for.assert_not_called()
     cx1.reset.assert_not_called()
 
@@ -362,7 +369,7 @@ def liveness_side_effect(*args, **kwargs):
     cx1.reset.assert_not_called()
 
     # then acquire it again and assert the liveness check was performed
-    cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx2 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
                               liveness_timeout)
     assert cx1 is not cx2
     assert cx1.unresolved_address == cx2.unresolved_address
@@ -384,14 +391,14 @@ def liveness_side_effect(*args, **kwargs):
     liveness_timeout = 1
     pool = _simple_pool(opener)
     # populate the pool with a connection
-    cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
                               liveness_timeout)
-    cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx2 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
                               liveness_timeout)
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER_ADDRESS
-    assert cx2.unresolved_address == READER_ADDRESS
+    assert cx1.unresolved_address == READER1_ADDRESS
+    assert cx2.unresolved_address == READER1_ADDRESS
     assert cx1 is not cx2
     cx1.is_idle_for.assert_not_called()
     cx2.is_idle_for.assert_not_called()
@@ -409,7 +416,7 @@ def liveness_side_effect(*args, **kwargs):
     cx2.reset.assert_not_called()
 
     # then acquire it again and assert the liveness check was performed
-    cx3 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx3 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
                               liveness_timeout)
     assert cx3 is cx2
     cx1.is_idle_for.assert_called_once_with(liveness_timeout)
@@ -426,7 +433,7 @@ def mock_connection_breaks_on_close(cx):
     async def close_side_effect():
         cx.closed.return_value = True
         cx.defunct.return_value = True
-        await pool.deactivate(READER_ADDRESS)
+        await pool.deactivate(READER1_ADDRESS)
 
     cx.attach_mock(mocker.AsyncMock(side_effect=close_side_effect),
                    "close")
@@ -470,9 +477,9 @@ async def test__acquire_new_later_with_room(opener):
     pool = AsyncNeo4jPool(
         opener, config, WorkspaceConfig(), ROUTER1_ADDRESS
    )
-    assert pool.connections_reservations[READER_ADDRESS] == 0
-    creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1))
-    assert pool.connections_reservations[READER_ADDRESS] == 1
+    assert pool.connections_reservations[READER1_ADDRESS] == 0
+    creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1))
+    assert pool.connections_reservations[READER1_ADDRESS] == 1
     assert callable(creator)
     if AsyncUtil.is_async_code:
         assert inspect.iscoroutinefunction(creator)
@@ -487,9 +494,9 @@ async def test__acquire_new_later_without_room(opener):
     )
     _ = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     # pool is full now
-    assert pool.connections_reservations[READER_ADDRESS] == 0
-    creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1))
-    assert pool.connections_reservations[READER_ADDRESS] == 0
+    assert pool.connections_reservations[READER1_ADDRESS] == 0
+    creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1))
+    assert pool.connections_reservations[READER1_ADDRESS] == 0
     assert creator is None
 
 
@@ -519,8 +526,8 @@ async def test_passes_pool_config_to_connection(mocker):
      "Neo.ClientError.Security.AuthorizationExpired"),
 ))
 @mark_async_test
-async def test_discovery_is_retried(routing_failure_opener, error):
-    opener = routing_failure_opener([
+async def test_discovery_is_retried(custom_routing_opener, error):
+    opener = custom_routing_opener([
         None,  # first call to router for seeding the RT with more routers
         error,  # will be retried
     ])
@@ -563,8 +570,8 @@ async def test_discovery_is_retried(routing_failure_opener, error):
     )
 ))
 @mark_async_test
-async def test_fast_failing_discovery(routing_failure_opener, error):
-    opener = routing_failure_opener([
+async def test_fast_failing_discovery(custom_routing_opener, error):
+    opener = custom_routing_opener([
         None,  # first call to router for seeding the RT with more routers
         error,  # will be retried
     ])
@@ -648,3 +655,85 @@ async def test_connection_error_callback(
         cx.mark_unauthenticated.assert_not_called()
     for cx in cxs_write:
         cx.mark_unauthenticated.assert_not_called()
+
+
+@mark_async_test
+async def test_pool_closes_connections_dropped_from_rt(custom_routing_opener):
+    readers = {"db1": [str(READER1_ADDRESS)]}
+
+    def get_readers(database):
+        return readers[database]
+
+    opener = custom_routing_opener(get_readers=get_readers)
+
+    pool = AsyncNeo4jPool(
+        opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS
+    )
+    cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
+    assert cx1.unresolved_address == READER1_ADDRESS
+    await pool.release(cx1)
+
+    cx1.close.assert_not_called()
+    assert len(pool.connections[READER1_ADDRESS]) == 1
+
+    # force RT refresh, returning a different reader
+    del pool.routing_tables["db1"]
+    readers["db1"] = [str(READER2_ADDRESS)]
+
+    cx2 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
+    assert cx2.unresolved_address == READER2_ADDRESS
+
+    cx1.close.assert_awaited_once()
+    assert len(pool.connections[READER1_ADDRESS]) == 0
+
+    await pool.release(cx2)
+    assert len(pool.connections[READER2_ADDRESS]) == 1
+
+
+@mark_async_test
+async def test_pool_does_not_close_connections_dropped_from_rt_for_other_server(
+    custom_routing_opener
+):
+    readers = {
+        "db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)],
+        "db2": [str(READER1_ADDRESS)]
+    }
+
+    def get_readers(database):
+        return readers[database]
+
+    opener = custom_routing_opener(get_readers=get_readers)
+
+    pool = AsyncNeo4jPool(
+        opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS
+    )
+    cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
+    await pool.release(cx1)
+    assert cx1.unresolved_address in (READER1_ADDRESS, READER2_ADDRESS)
+    reader1_connection_count = len(pool.connections[READER1_ADDRESS])
+    reader2_connection_count = len(pool.connections[READER2_ADDRESS])
+    assert reader1_connection_count + reader2_connection_count == 1
+
+    cx2 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None)
+    await pool.release(cx2)
+    assert cx2.unresolved_address == READER1_ADDRESS
+    cx1.close.assert_not_called()
+    cx2.close.assert_not_called()
+    assert len(pool.connections[READER1_ADDRESS]) == 1
+    assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count
+
+
+    # force RT refresh, returning a different reader
+    del pool.routing_tables["db2"]
+    readers["db2"] = [str(READER3_ADDRESS)]
+
+    cx3 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None)
+    await pool.release(cx3)
+    assert cx3.unresolved_address == READER3_ADDRESS
+
+    cx1.close.assert_not_called()
+    cx2.close.assert_not_called()
+    cx3.close.assert_not_called()
+    assert len(pool.connections[READER1_ADDRESS]) == 1
+    assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count
+    assert len(pool.connections[READER3_ADDRESS]) == 1
diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py
index cfaaf1f34..3a5a2e79b 100644
--- a/tests/unit/sync/io/test_neo4j_pool.py
+++ b/tests/unit/sync/io/test_neo4j_pool.py
@@ -17,6 +17,7 @@
 
 
 import inspect
+from collections import defaultdict
 
 import pytest
 
@@ -50,17 +51,23 @@
 ROUTER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host")
 ROUTER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
 ROUTER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host")
-READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host")
-WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host")
+READER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host")
+READER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9011), host_name="host")
+READER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9012), host_name="host")
+WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host")
 
 
 @pytest.fixture
-def routing_failure_opener(fake_connection_generator, mocker):
-    def make_opener(failures=None):
+def custom_routing_opener(fake_connection_generator, mocker):
+    def make_opener(failures=None, get_readers=None):
         def routing_side_effect(*args, **kwargs):
             nonlocal failures
             res = next(failures, None)
             if res is None:
+                if get_readers is not None:
+                    readers = get_readers(kwargs.get("database"))
+                else:
+                    readers = [str(READER1_ADDRESS)]
                 return [{
                     "ttl": 1000,
                     "servers": [
@@ -68,8 +75,8 @@ def routing_side_effect(*args, **kwargs):
                             str(ROUTER2_ADDRESS),
                             str(ROUTER3_ADDRESS)],
                          "role": "ROUTE"},
-                        {"addresses": [str(READER_ADDRESS)], "role": "READ"},
-                        {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"},
+                        {"addresses": readers, "role": "READ"},
+                        {"addresses": [str(WRITER1_ADDRESS)], "role": "WRITE"},
                     ],
                 }]
             raise res
@@ -96,8 +103,8 @@ def open_(addr, auth, timeout):
 
 
 @pytest.fixture
-def opener(routing_failure_opener):
-    return routing_failure_opener()
+def opener(custom_routing_opener):
+    return custom_routing_opener()
 
 
 def _pool_config():
@@ -177,9 +184,9 @@ def test_chooses_right_connection_type(opener, type_):
     )
     pool.release(cx1)
     if type_ == "r":
-        assert cx1.unresolved_address == READER_ADDRESS
+        assert cx1.unresolved_address == READER1_ADDRESS
     else:
-        assert cx1.unresolved_address == WRITER_ADDRESS
+        assert cx1.unresolved_address == WRITER1_ADDRESS
 
 
 @mark_sync_test
@@ -298,9 +305,9 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection(
     opener, liveness_timeout
 ):
     pool = _simple_pool(opener)
-    cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx1 = pool._acquire(READER1_ADDRESS, None, Deadline(30),
                         liveness_timeout)
-    assert cx1.unresolved_address == READER_ADDRESS
+    assert cx1.unresolved_address == READER1_ADDRESS
     cx1.reset.assert_not_called()
 
 
@@ -311,11 +318,11 @@ def test_acquire_performs_liveness_check_on_existing_connection(
 ):
     pool = _simple_pool(opener)
     # populate the pool with a connection
-    cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx1 = pool._acquire(READER1_ADDRESS, None, Deadline(30),
                         liveness_timeout)
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER_ADDRESS
+    assert cx1.unresolved_address == READER1_ADDRESS
     cx1.is_idle_for.assert_not_called()
     cx1.reset.assert_not_called()
 
@@ -326,7 +333,7 @@ def test_acquire_performs_liveness_check_on_existing_connection(
     cx1.reset.assert_not_called()
 
     # then acquire it again and assert the liveness check was performed
-    cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx2 = pool._acquire(READER1_ADDRESS, None, Deadline(30),
                         liveness_timeout)
     assert cx1 is cx2
     cx1.is_idle_for.assert_called_once_with(liveness_timeout)
@@ -345,11 +352,11 @@ def liveness_side_effect(*args, **kwargs):
     liveness_timeout = 1
     pool = _simple_pool(opener)
     # populate the pool with a connection
-    cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx1 = pool._acquire(READER1_ADDRESS, None, Deadline(30),
                         liveness_timeout)
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER_ADDRESS
+    assert cx1.unresolved_address == READER1_ADDRESS
     cx1.is_idle_for.assert_not_called()
     cx1.reset.assert_not_called()
 
@@ -362,7 +369,7 @@ def liveness_side_effect(*args, **kwargs):
     cx1.reset.assert_not_called()
 
     # then acquire it again and assert the liveness check was performed
-    cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx2 = pool._acquire(READER1_ADDRESS, None, Deadline(30),
                         liveness_timeout)
     assert cx1 is not cx2
    assert cx1.unresolved_address == cx2.unresolved_address
@@ -384,14 +391,14 @@ def liveness_side_effect(*args, **kwargs):
     liveness_timeout = 1
     pool = _simple_pool(opener)
     # populate the pool with a connection
-    cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx1 = pool._acquire(READER1_ADDRESS, None, Deadline(30),
                         liveness_timeout)
-    cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx2 = pool._acquire(READER1_ADDRESS, None, Deadline(30),
                         liveness_timeout)
 
     # make sure we assume the right state
-    assert cx1.unresolved_address == READER_ADDRESS
-    assert cx2.unresolved_address == READER_ADDRESS
+    assert cx1.unresolved_address == READER1_ADDRESS
+    assert cx2.unresolved_address == READER1_ADDRESS
     assert cx1 is not cx2
     cx1.is_idle_for.assert_not_called()
     cx2.is_idle_for.assert_not_called()
@@ -409,7 +416,7 @@ def liveness_side_effect(*args, **kwargs):
     cx2.reset.assert_not_called()
 
     # then acquire it again and assert the liveness check was performed
-    cx3 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+    cx3 = pool._acquire(READER1_ADDRESS, None, Deadline(30),
                         liveness_timeout)
     assert cx3 is cx2
     cx1.is_idle_for.assert_called_once_with(liveness_timeout)
@@ -426,7 +433,7 @@ def mock_connection_breaks_on_close(cx):
     def close_side_effect():
         cx.closed.return_value = True
         cx.defunct.return_value = True
-        pool.deactivate(READER_ADDRESS)
+        pool.deactivate(READER1_ADDRESS)
 
     cx.attach_mock(mocker.MagicMock(side_effect=close_side_effect),
                    "close")
@@ -470,9 +477,9 @@ def test__acquire_new_later_with_room(opener):
     pool = Neo4jPool(
         opener, config, WorkspaceConfig(), ROUTER1_ADDRESS
     )
-    assert pool.connections_reservations[READER_ADDRESS] == 0
-    creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1))
-    assert pool.connections_reservations[READER_ADDRESS] == 1
+    assert pool.connections_reservations[READER1_ADDRESS] == 0
+    creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1))
+    assert pool.connections_reservations[READER1_ADDRESS] == 1
     assert callable(creator)
     if Util.is_async_code:
         assert inspect.iscoroutinefunction(creator)
@@ -487,9 +494,9 @@ def test__acquire_new_later_without_room(opener):
     )
     _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     # pool is full now
-    assert pool.connections_reservations[READER_ADDRESS] == 0
-    creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1))
-    assert pool.connections_reservations[READER_ADDRESS] == 0
+    assert pool.connections_reservations[READER1_ADDRESS] == 0
+    creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1))
+    assert pool.connections_reservations[READER1_ADDRESS] == 0
     assert creator is None
 
 
@@ -519,8 +526,8 @@ def test_passes_pool_config_to_connection(mocker):
      "Neo.ClientError.Security.AuthorizationExpired"),
 ))
 @mark_sync_test
-def test_discovery_is_retried(routing_failure_opener, error):
-    opener = routing_failure_opener([
+def test_discovery_is_retried(custom_routing_opener, error):
+    opener = custom_routing_opener([
         None,  # first call to router for seeding the RT with more routers
         error,  # will be retried
     ])
@@ -563,8 +570,8 @@ def test_discovery_is_retried(routing_failure_opener, error):
     )
 ))
 @mark_sync_test
-def test_fast_failing_discovery(routing_failure_opener, error):
-    opener = routing_failure_opener([
+def test_fast_failing_discovery(custom_routing_opener, error):
+    opener = custom_routing_opener([
         None,  # first call to router for seeding the RT with more routers
         error,  # will be retried
     ])
@@ -648,3 +655,85 @@ def test_connection_error_callback(
         cx.mark_unauthenticated.assert_not_called()
     for cx in cxs_write:
         cx.mark_unauthenticated.assert_not_called()
+
+
+@mark_sync_test
+def test_pool_closes_connections_dropped_from_rt(custom_routing_opener):
+    readers = {"db1": [str(READER1_ADDRESS)]}
+
+    def get_readers(database):
+        return readers[database]
+
+    opener = custom_routing_opener(get_readers=get_readers)
+
+    pool = Neo4jPool(
+        opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS
+    )
+    cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
+    assert cx1.unresolved_address == READER1_ADDRESS
+    pool.release(cx1)
+
+    cx1.close.assert_not_called()
+    assert len(pool.connections[READER1_ADDRESS]) == 1
+
+    # force RT refresh, returning a different reader
+    del pool.routing_tables["db1"]
+    readers["db1"] = [str(READER2_ADDRESS)]
+
+    cx2 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
+    assert cx2.unresolved_address == READER2_ADDRESS
+
+    cx1.close.assert_called_once()
+    assert len(pool.connections[READER1_ADDRESS]) == 0
+
+    pool.release(cx2)
+    assert len(pool.connections[READER2_ADDRESS]) == 1
+
+
+@mark_sync_test
+def test_pool_does_not_close_connections_dropped_from_rt_for_other_server(
+    custom_routing_opener
+):
+    readers = {
+        "db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)],
+        "db2": [str(READER1_ADDRESS)]
+    }
+
+    def get_readers(database):
+        return readers[database]
+
+    opener = custom_routing_opener(get_readers=get_readers)
+
+    pool = Neo4jPool(
+        opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS
+    )
+    cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
+    pool.release(cx1)
+    assert cx1.unresolved_address in (READER1_ADDRESS, READER2_ADDRESS)
+    reader1_connection_count = len(pool.connections[READER1_ADDRESS])
+    reader2_connection_count = len(pool.connections[READER2_ADDRESS])
+    assert reader1_connection_count + reader2_connection_count == 1
+
+    cx2 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None)
+    pool.release(cx2)
+    assert cx2.unresolved_address == READER1_ADDRESS
+    cx1.close.assert_not_called()
+    cx2.close.assert_not_called()
+    assert len(pool.connections[READER1_ADDRESS]) == 1
+    assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count
+
+
+    # force RT refresh, returning a different reader
+    del pool.routing_tables["db2"]
+    readers["db2"] = [str(READER3_ADDRESS)]
+
+    cx3 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None)
+    pool.release(cx3)
+    assert cx3.unresolved_address == READER3_ADDRESS
+
+    cx1.close.assert_not_called()
+    cx2.close.assert_not_called()
+    cx3.close.assert_not_called()
+    assert len(pool.connections[READER1_ADDRESS]) == 1
+    assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count
+    assert len(pool.connections[READER3_ADDRESS]) == 1