diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index bef051054..2435de58a 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -492,6 +492,25 @@ def in_use_connection_count(self, address): else: return sum(1 if connection.in_use else 0 for connection in connections) + def deactivate(self, address): + """ Deactivate an address from the connection pool, if present, closing + all idle connection to that address + """ + with self.lock: + try: + connections = self.connections[address] + except KeyError: # already removed from the connection pool + return + for conn in list(connections): + if not conn.in_use: + connections.remove(conn) + try: + conn.close() + except IOError: + pass + if not connections: + self.remove(address) + def remove(self, address): """ Remove an address from the connection pool, if present, closing all connections to that address. diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 4f9fb33fb..a36815add 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -140,6 +140,9 @@ def update(self, new_routing_table): self.last_updated_time = self.timer() self.ttl = new_routing_table.ttl + def servers(self): + return set(self.routers) | set(self.writers) | set(self.readers) + class RoutingSession(BoltSession): @@ -249,9 +252,9 @@ class RoutingConnectionErrorHandler(ConnectionErrorHandler): def __init__(self, pool): super(RoutingConnectionErrorHandler, self).__init__({ - SessionExpired: lambda address: pool.remove(address), - ServiceUnavailable: lambda address: pool.remove(address), - DatabaseUnavailableError: lambda address: pool.remove(address), + SessionExpired: lambda address: pool.deactivate(address), + ServiceUnavailable: lambda address: pool.deactivate(address), + DatabaseUnavailableError: lambda address: pool.deactivate(address), NotALeaderError: lambda address: pool.remove_writer(address), ForbiddenOnReadOnlyDatabaseError: lambda address: pool.remove_writer(address) }) @@ -288,7 +291,7 @@ def fetch_routing_info(self, address): else: raise ServiceUnavailable("Routing support broken on server {!r}".format(address)) except ServiceUnavailable: - self.remove(address) + self.deactivate(address) return None def fetch_routing_table(self, address): @@ -365,6 +368,12 @@ def update_routing_table(self): # None of the routers have been successful, so just fail raise ServiceUnavailable("Unable to retrieve routing information") + def update_connection_pool(self): + servers = self.routing_table.servers() + for address in list(self.connections): + if address not in servers: + super(RoutingConnectionPool, self).deactivate(address) + def ensure_routing_table_is_fresh(self, access_mode): """ Update the routing table if stale. @@ -387,6 +396,7 @@ def ensure_routing_table_is_fresh(self, access_mode): self.missing_writer = not self.routing_table.is_fresh(WRITE_ACCESS) return False self.update_routing_table() + self.update_connection_pool() return True def acquire(self, access_mode=None): @@ -410,21 +420,22 @@ def acquire(self, access_mode=None): connection = self.acquire_direct(address) # should always be a resolved address connection.Error = SessionExpired except ServiceUnavailable: - self.remove(address) + self.deactivate(address) else: return connection raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) - def remove(self, address): - """ Remove an address from the connection pool, if present, closing - all connections to that address. Also remove from the routing table. + def deactivate(self, address): + """ Deactivate an address from the connection pool, + if present, remove from the routing table and also closing + all idle connections to that address. """ # We use `discard` instead of `remove` here since the former # will not fail if the address has already been removed. self.routing_table.routers.discard(address) self.routing_table.readers.discard(address) self.routing_table.writers.discard(address) - super(RoutingConnectionPool, self).remove(address) + super(RoutingConnectionPool, self).deactivate(address) def remove_writer(self, address): """ Remove a writer address from the routing table, if present. diff --git a/test/stub/scripts/router_with_multiple_servers.script b/test/stub/scripts/router_with_multiple_servers.script new file mode 100644 index 000000000..f520dc60e --- /dev/null +++ b/test/stub/scripts/router_with_multiple_servers.script @@ -0,0 +1,8 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "CALL dbms.cluster.routing.getServers" {} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [300, [{"role":"ROUTE","addresses":["127.0.0.1:9001","127.0.0.1:9002"]},{"role":"READ","addresses":["127.0.0.1:9001","127.0.0.1:9003"]},{"role":"WRITE","addresses":["127.0.0.1:9004"]}]] + SUCCESS {} \ No newline at end of file diff --git a/test/stub/test_routing.py b/test/stub/test_routing.py index 05f2bfca2..64e437b49 100644 --- a/test/stub/test_routing.py +++ b/test/stub/test_routing.py @@ -303,6 +303,59 @@ def test_should_flag_reading_without_writer(self): pool.ensure_routing_table_is_fresh(READ_ACCESS) assert pool.missing_writer + def test_should_purge_idle_connections_from_connection_pool(self): + with StubCluster({9006: "router.script", 9001: "router_with_multiple_servers.script"}): + address = ("127.0.0.1", 9006) + with RoutingPool(address) as pool: + # close the acquired connection with init router and then set it to be idle + conn = pool.acquire(WRITE_ACCESS) + conn.close() + conn.in_use = False + + table = pool.routing_table + assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), + ("127.0.0.1", 9003)} + assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)} + assert table.writers == {("127.0.0.1", 9006)} + assert set(pool.connections.keys()) == {("127.0.0.1", 9006)} + + # immediately expire the routing table to enforce update a new routing table + pool.routing_table.ttl = 0 + pool.ensure_routing_table_is_fresh(WRITE_ACCESS) + table = pool.routing_table + assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002)} + assert table.readers == {("127.0.0.1", 9001), ("127.0.0.1", 9003)} + assert table.writers == {("127.0.0.1", 9004)} + + assert set(pool.connections.keys()) == {("127.0.0.1", 9001)} + + def test_should_not_purge_idle_connections_from_connection_pool(self): + with StubCluster({9006: "router.script", 9001: "router_with_multiple_servers.script"}): + address = ("127.0.0.1", 9006) + with RoutingPool(address) as pool: + # close the acquired connection with init router and then set it to be inUse + conn = pool.acquire(WRITE_ACCESS) + conn.close() + conn.in_use = True + + table = pool.routing_table + assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), + ("127.0.0.1", 9003)} + assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)} + assert table.writers == {("127.0.0.1", 9006)} + assert set(pool.connections.keys()) == {("127.0.0.1", 9006)} + + # immediately expire the routing table to enforce update a new routing table + pool.routing_table.ttl = 0 + pool.ensure_routing_table_is_fresh(WRITE_ACCESS) + table = pool.routing_table + assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002)} + assert table.readers == {("127.0.0.1", 9001), ("127.0.0.1", 9003)} + assert table.writers == {("127.0.0.1", 9004)} + + assert set(pool.connections.keys()) == {("127.0.0.1", 9001), ("127.0.0.1", 9006)} + + # TODO: fix flaky test # def test_concurrent_refreshes_should_not_block_if_fresh(self): # address = ("127.0.0.1", 9001) @@ -481,7 +534,7 @@ def test_should_error_to_writer_in_absent_of_reader(self): assert not pool.missing_writer -class RoutingConnectionPoolRemoveTestCase(StubTestCase): +class RoutingConnectionPoolDeactivateTestCase(StubTestCase): def test_should_remove_router_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) @@ -489,7 +542,7 @@ def test_should_remove_router_from_routing_table_if_present(self): pool.ensure_routing_table_is_fresh(WRITE_ACCESS) target = ("127.0.0.1", 9001) assert target in pool.routing_table.routers - pool.remove(target) + pool.deactivate(target) assert target not in pool.routing_table.routers def test_should_remove_reader_from_routing_table_if_present(self): @@ -499,7 +552,7 @@ def test_should_remove_reader_from_routing_table_if_present(self): pool.ensure_routing_table_is_fresh(WRITE_ACCESS) target = ("127.0.0.1", 9004) assert target in pool.routing_table.readers - pool.remove(target) + pool.deactivate(target) assert target not in pool.routing_table.readers def test_should_remove_writer_from_routing_table_if_present(self): @@ -509,7 +562,7 @@ def test_should_remove_writer_from_routing_table_if_present(self): pool.ensure_routing_table_is_fresh(WRITE_ACCESS) target = ("127.0.0.1", 9006) assert target in pool.routing_table.writers - pool.remove(target) + pool.deactivate(target) assert target not in pool.routing_table.writers def test_should_not_fail_if_absent(self): @@ -518,4 +571,4 @@ def test_should_not_fail_if_absent(self): with RoutingPool(address) as pool: pool.ensure_routing_table_is_fresh(WRITE_ACCESS) target = ("127.0.0.1", 9007) - pool.remove(target) + pool.deactivate(target) diff --git a/test/stub/test_routingdriver.py b/test/stub/test_routingdriver.py index ab2163ccc..8457b8c65 100644 --- a/test/stub/test_routingdriver.py +++ b/test/stub/test_routingdriver.py @@ -289,13 +289,19 @@ def test_forgets_address_on_service_unavailable_error(self): pool = driver._pool table = pool.routing_table - # address should not have connections in the pool, it has failed - assert ('127.0.0.1', 9004) not in pool.connections + # address should have connections in the pool but be inactive, it has failed + assert ('127.0.0.1', 9004) in pool.connections + conns = pool.connections[('127.0.0.1', 9004)] + conn = conns[0] + assert conn._closed == True + assert conn.in_use == True assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} # reader 127.0.0.1:9004 should've been forgotten because of an error assert table.readers == {('127.0.0.1', 9005)} assert table.writers == {('127.0.0.1', 9006)} + assert conn.in_use == False + def test_forgets_address_on_database_unavailable_error(self): with StubCluster({9001: "router.script", 9004: "database_unavailable.script"}): uri = "bolt+routing://127.0.0.1:9001" diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index a7d12c4df..8128af16e 100644 --- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -169,6 +169,20 @@ def test_should_fail_on_multiple_records(self): _ = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD, VALID_ROUTING_RECORD]) +class RoutingTableServersTestCase(TestCase): + def test_should_return_all_distinct_servers_in_routing_table(self): + routing_table = { + "ttl": 300, + "servers": [ + {"role": "ROUTE", "addresses": ["127.0.0.1:9001", "127.0.0.1:9002", "127.0.0.1:9003"]}, + {"role": "READ", "addresses": ["127.0.0.1:9001", "127.0.0.1:9005"]}, + {"role": "WRITE", "addresses": ["127.0.0.1:9002"]}, + ], + } + table = RoutingTable.parse_routing_info([routing_table]) + assert table.servers() == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003), ('127.0.0.1', 9005)} + + class RoutingTableFreshnessTestCase(TestCase): def test_should_be_fresh_after_update(self): table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD])