def in_use_connection_count(self, address):
    """ Count the number of connections currently in use to a given
    address.

    :param address: key under which the connections for one server are
        stored in ``self.connections``
    :return: how many connections held for ``address`` have a truthy
        ``in_use`` flag; ``0`` when there is no entry for ``address``
    """
    try:
        connections = self.connections[address]
    except KeyError:
        # No entry for this address, so nothing can be in use.
        return 0
    # Count only the connections flagged as in use.
    return sum(1 for connection in connections if connection.in_use)
- - +from abc import abstractmethod +from sys import maxsize from threading import Lock from time import clock @@ -26,14 +26,19 @@ from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect from neo4j.compat.collections import MutableSet, OrderedDict from neo4j.exceptions import CypherError +from neo4j.util import ServerVersion from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters from neo4j.v1.exceptions import SessionExpired from neo4j.v1.security import SecurityPlan from neo4j.v1.session import BoltSession -from neo4j.util import ServerVersion -class RoundRobinSet(MutableSet): +LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0 +LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1 +LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED + + +class OrderedSet(MutableSet): def __init__(self, elements=()): self._elements = OrderedDict.fromkeys(elements) @@ -45,22 +50,15 @@ def __repr__(self): def __contains__(self, element): return element in self._elements - def __next__(self): - current = None - if self._elements: - if self._current is None: - self._current = 0 - else: - self._current = (self._current + 1) % len(self._elements) - current = list(self._elements.keys())[self._current] - return current - def __iter__(self): return iter(self._elements) def __len__(self): return len(self._elements) + def __getitem__(self, index): + return list(self._elements.keys())[index] + def add(self, element): self._elements[element] = None @@ -73,9 +71,6 @@ def discard(self, element): except KeyError: pass - def next(self): - return self.__next__() - def remove(self, element): try: del self._elements[element] @@ -126,9 +121,9 @@ def parse_routing_info(cls, records): return cls(routers, readers, writers, ttl) def __init__(self, routers=(), readers=(), writers=(), ttl=0): - self.routers = RoundRobinSet(routers) - self.readers = RoundRobinSet(readers) - self.writers = RoundRobinSet(writers) + self.routers = 
class LoadBalancingStrategy(object):
    """ Base class for objects that choose one server address from the
    readers or writers handed to them.

    Subclasses implement :meth:`select_reader` and :meth:`select_writer`.
    """

    @classmethod
    def build(cls, connection_pool, **config):
        """ Create the strategy named by the ``load_balancing_strategy``
        entry of ``config``.

        :param connection_pool: pool passed to the least-connected strategy
        :param config: keyword settings; only ``load_balancing_strategy``
            is read, falling back to ``LOAD_BALANCING_STRATEGY_DEFAULT``
        :return: a new strategy instance
        :raise ValueError: if the configured value is not a known strategy
        """
        load_balancing_strategy = config.get("load_balancing_strategy", LOAD_BALANCING_STRATEGY_DEFAULT)
        if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED:
            return LeastConnectedLoadBalancingStrategy(connection_pool)
        elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN:
            return RoundRobinLoadBalancingStrategy()
        else:
            raise ValueError("Unknown load balancing strategy '%s'" % load_balancing_strategy)

    # NOTE: this class declares no ABCMeta metaclass, so ``abstractmethod``
    # acts only as a marker here; the NotImplementedError below is what a
    # caller of an un-overridden method actually sees.
    @abstractmethod
    def select_reader(self, known_readers):
        """ Return the address to use for a read, or ``None`` if
        ``known_readers`` is empty.
        """
        raise NotImplementedError()

    @abstractmethod
    def select_writer(self, known_writers):
        """ Return the address to use for a write, or ``None`` if
        ``known_writers`` is empty.
        """
        raise NotImplementedError()


class RoundRobinLoadBalancingStrategy(LoadBalancingStrategy):
    """ Strategy that walks through the known addresses in order, keeping
    separate positions for readers and writers.
    """

    # Class-level defaults; the first ``+=`` on an instance creates the
    # per-instance counter.
    _readers_offset = 0
    _writers_offset = 0

    def select_reader(self, known_readers):
        """ Return the next reader in rotation (``None`` if there are none). """
        address = self._select(self._readers_offset, known_readers)
        self._readers_offset += 1
        return address

    def select_writer(self, known_writers):
        """ Return the next writer in rotation (``None`` if there are none). """
        address = self._select(self._writers_offset, known_writers)
        self._writers_offset += 1
        return address

    @classmethod
    def _select(cls, offset, addresses):
        """ Pick the address at position ``offset`` modulo the number of
        addresses, or ``None`` when there are no addresses.
        """
        # Snapshot once so the length and the lookup refer to the same
        # sequence even if ``addresses`` is modified in between.
        candidates = list(addresses or ())
        if not candidates:
            return None
        return candidates[offset % len(candidates)]


class LeastConnectedLoadBalancingStrategy(LoadBalancingStrategy):
    """ Strategy that picks the address with the fewest in-use connections
    according to the connection pool, rotating the starting point on each
    call so that ties are not always resolved in favour of the same server.
    """

    def __init__(self, connection_pool):
        """
        :param connection_pool: object providing
            ``in_use_connection_count(address)``
        """
        self._readers_offset = 0
        self._writers_offset = 0
        self._connection_pool = connection_pool

    def select_reader(self, known_readers):
        """ Return the least-connected reader (``None`` if there are none). """
        address = self._select(self._readers_offset, known_readers)
        self._readers_offset += 1
        return address

    def select_writer(self, known_writers):
        """ Return the least-connected writer (``None`` if there are none). """
        address = self._select(self._writers_offset, known_writers)
        self._writers_offset += 1
        return address

    def _select(self, offset, addresses):
        """ Return the address with the lowest in-use connection count.

        The scan starts at ``offset`` modulo the number of addresses and
        wraps around; on a tie the first address met in that order wins.
        Returns ``None`` when there are no addresses.
        """
        # Snapshot once: indexing the collection repeatedly is needlessly
        # expensive for containers that rebuild a list on every lookup, and
        # would fail if the collection shrank part-way through the scan.
        candidates = list(addresses or ())
        if not candidates:
            return None
        start_index = offset % len(candidates)
        rotated = candidates[start_index:] + candidates[:start_index]
        # min() returns the first minimal item, which matches a scan that
        # only replaces the current best on a strictly smaller count.
        return min(rotated, key=self._connection_pool.in_use_connection_count)
@@ -304,14 +384,16 @@ def acquire(self, access_mode=None): access_mode = WRITE_ACCESS if access_mode == READ_ACCESS: server_list = self.routing_table.readers + server_selector = self.load_balancing_strategy.select_reader elif access_mode == WRITE_ACCESS: server_list = self.routing_table.writers + server_selector = self.load_balancing_strategy.select_writer else: raise ValueError("Unsupported access mode {}".format(access_mode)) self.ensure_routing_table_is_fresh(access_mode) while True: - address = next(server_list) + address = server_selector(server_list) if address is None: break try: @@ -354,7 +436,7 @@ def __init__(self, uri, **config): def connector(a): return connect(a, security_plan.ssl_context, **config) - pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address)) + pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address), **config) try: pool.update_routing_table() except: diff --git a/test/integration/test_connection.py b/test/integration/test_connection.py index 897b6f47e..1a28c2ff9 100644 --- a/test/integration/test_connection.py +++ b/test/integration/test_connection.py @@ -108,3 +108,11 @@ def test_cannot_acquire_after_close(self): pool.close() with self.assertRaises(ServiceUnavailable): _ = pool.acquire_direct("X") + + def test_in_use_count(self): + address = ("127.0.0.1", 7687) + self.assertEqual(self.pool.in_use_connection_count(address), 0) + connection = self.pool.acquire_direct(address) + self.assertEqual(self.pool.in_use_connection_count(address), 1) + self.pool.release(connection) + self.assertEqual(self.pool.in_use_connection_count(address), 0) diff --git a/test/stub/test_routingdriver.py b/test/stub/test_routingdriver.py index debe7050a..0edefe666 100644 --- a/test/stub/test_routingdriver.py +++ b/test/stub/test_routingdriver.py @@ -19,7 +19,9 @@ # limitations under the License. 
-from neo4j.v1 import GraphDatabase, RoutingDriver, READ_ACCESS, WRITE_ACCESS, SessionExpired +from neo4j.v1 import GraphDatabase, READ_ACCESS, WRITE_ACCESS, SessionExpired, \ + RoutingDriver, RoutingConnectionPool, LeastConnectedLoadBalancingStrategy, LOAD_BALANCING_STRATEGY_ROUND_ROBIN, \ + RoundRobinLoadBalancingStrategy from neo4j.bolt import ProtocolError, ServiceUnavailable from test.stub.tools import StubTestCase, StubCluster @@ -215,3 +217,27 @@ def test_should_error_when_missing_reader(self): uri = "bolt+routing://127.0.0.1:9001" with self.assertRaises(ProtocolError): GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) + + def test_default_load_balancing_strategy_is_least_connected(self): + with StubCluster({9001: "router.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + self.assertIsInstance(driver, RoutingDriver) + self.assertIsInstance(driver._pool, RoutingConnectionPool) + self.assertIsInstance(driver._pool.load_balancing_strategy, LeastConnectedLoadBalancingStrategy) + + def test_can_select_round_robin_load_balancing_strategy(self): + with StubCluster({9001: "router.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, + load_balancing_strategy=LOAD_BALANCING_STRATEGY_ROUND_ROBIN) as driver: + self.assertIsInstance(driver, RoutingDriver) + self.assertIsInstance(driver._pool, RoutingConnectionPool) + self.assertIsInstance(driver._pool.load_balancing_strategy, RoundRobinLoadBalancingStrategy) + + def test_no_other_load_balancing_strategies_are_available(self): + with StubCluster({9001: "router.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with self.assertRaises(ValueError): + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, load_balancing_strategy=-1): + pass diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index 2b9fce5ce..92ca5507d 100644 
--- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -17,14 +17,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from collections import OrderedDict from unittest import TestCase from neo4j.bolt import ProtocolError from neo4j.bolt.connection import connect -from neo4j.v1.routing import RoundRobinSet, RoutingTable, RoutingConnectionPool +from neo4j.v1.routing import OrderedSet, RoutingTable, RoutingConnectionPool, LeastConnectedLoadBalancingStrategy, \ + RoundRobinLoadBalancingStrategy from neo4j.v1.security import basic_auth -from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS +from neo4j.v1.api import READ_ACCESS, WRITE_ACCESS VALID_ROUTING_RECORD = { @@ -56,93 +57,84 @@ def connector(address): class RoundRobinSetTestCase(TestCase): - def test_should_repr_as_set(self): - rrs = RoundRobinSet([1, 2, 3]) - assert repr(rrs) == "{1, 2, 3}" + s = OrderedSet([1, 2, 3]) + assert repr(s) == "{1, 2, 3}" def test_should_contain_element(self): - rrs = RoundRobinSet([1, 2, 3]) - assert 2 in rrs + s = OrderedSet([1, 2, 3]) + assert 2 in s def test_should_not_contain_non_element(self): - rrs = RoundRobinSet([1, 2, 3]) - assert 4 not in rrs - - def test_should_be_able_to_get_next_if_empty(self): - rrs = RoundRobinSet([]) - assert next(rrs) is None - - def test_should_be_able_to_get_next_repeatedly(self): - rrs = RoundRobinSet([1, 2, 3]) - assert next(rrs) == 1 - assert next(rrs) == 2 - assert next(rrs) == 3 - assert next(rrs) == 1 - - def test_should_be_able_to_get_next_repeatedly_via_old_method(self): - rrs = RoundRobinSet([1, 2, 3]) - assert rrs.next() == 1 - assert rrs.next() == 2 - assert rrs.next() == 3 - assert rrs.next() == 1 + s = OrderedSet([1, 2, 3]) + assert 4 not in s + + def test_should_be_able_to_get_item_if_empty(self): + s = OrderedSet([]) + with self.assertRaises(IndexError): + _ = s[0] + + def 
test_should_be_able_to_get_items_by_index(self): + s = OrderedSet([1, 2, 3]) + self.assertEqual(s[0], 1) + self.assertEqual(s[1], 2) + self.assertEqual(s[2], 3) def test_should_be_iterable(self): - rrs = RoundRobinSet([1, 2, 3]) - assert list(iter(rrs)) == [1, 2, 3] + s = OrderedSet([1, 2, 3]) + assert list(iter(s)) == [1, 2, 3] def test_should_have_length(self): - rrs = RoundRobinSet([1, 2, 3]) - assert len(rrs) == 3 + s = OrderedSet([1, 2, 3]) + assert len(s) == 3 def test_should_be_able_to_add_new(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.add(4) - assert list(rrs) == [1, 2, 3, 4] + s = OrderedSet([1, 2, 3]) + s.add(4) + assert list(s) == [1, 2, 3, 4] def test_should_be_able_to_add_existing(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.add(2) - assert list(rrs) == [1, 2, 3] + s = OrderedSet([1, 2, 3]) + s.add(2) + assert list(s) == [1, 2, 3] def test_should_be_able_to_clear(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.clear() - assert list(rrs) == [] + s = OrderedSet([1, 2, 3]) + s.clear() + assert list(s) == [] def test_should_be_able_to_discard_existing(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.discard(2) - assert list(rrs) == [1, 3] + s = OrderedSet([1, 2, 3]) + s.discard(2) + assert list(s) == [1, 3] def test_should_be_able_to_discard_non_existing(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.discard(4) - assert list(rrs) == [1, 2, 3] + s = OrderedSet([1, 2, 3]) + s.discard(4) + assert list(s) == [1, 2, 3] def test_should_be_able_to_remove_existing(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.remove(2) - assert list(rrs) == [1, 3] + s = OrderedSet([1, 2, 3]) + s.remove(2) + assert list(s) == [1, 3] def test_should_not_be_able_to_remove_non_existing(self): - rrs = RoundRobinSet([1, 2, 3]) + s = OrderedSet([1, 2, 3]) with self.assertRaises(ValueError): - rrs.remove(4) + s.remove(4) def test_should_be_able_to_update(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.update([3, 4, 5]) - assert list(rrs) == [1, 2, 3, 4, 5] + s = OrderedSet([1, 2, 3]) + 
class FakeConnectionPool(object):
    """ Stand-in for a connection pool that reports in-use connection
    counts from a mapping supplied at construction.
    """

    def __init__(self, addresses):
        self._addresses = addresses

    def in_use_connection_count(self, address):
        # Addresses missing from the mapping count as having no connections.
        return self._addresses.get(address, 0)


class RoundRobinLoadBalancingStrategyTestCase(TestCase):
    """ Selection order and empty-input handling of the round-robin
    strategy, for readers and writers.
    """

    def test_simple_reader_selection(self):
        strategy = RoundRobinLoadBalancingStrategy()
        # Four picks from three servers: the fourth wraps to the first.
        for expected in ("0.0.0.0", "1.1.1.1", "2.2.2.2", "0.0.0.0"):
            picked = strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"])
            self.assertEqual(picked, expected)

    def test_empty_reader_selection(self):
        strategy = RoundRobinLoadBalancingStrategy()
        picked = strategy.select_reader([])
        self.assertIsNone(picked)

    def test_simple_writer_selection(self):
        strategy = RoundRobinLoadBalancingStrategy()
        # Four picks from three servers: the fourth wraps to the first.
        for expected in ("0.0.0.0", "1.1.1.1", "2.2.2.2", "0.0.0.0"):
            picked = strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"])
            self.assertEqual(picked, expected)

    def test_empty_writer_selection(self):
        strategy = RoundRobinLoadBalancingStrategy()
        picked = strategy.select_writer([])
        self.assertIsNone(picked)


class LeastConnectedLoadBalancingStrategyTestCase(TestCase):
    """ Selection behaviour of the least-connected strategy against a fake
    pool, for readers and writers.
    """

    def test_simple_reader_selection(self):
        in_use = OrderedDict([("0.0.0.0", 2), ("1.1.1.1", 1), ("2.2.2.2", 0)])
        strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(in_use))
        picked = strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"])
        self.assertEqual(picked, "2.2.2.2")

    def test_reader_selection_with_clash(self):
        in_use = OrderedDict([("0.0.0.0", 0), ("0.0.0.1", 0), ("1.1.1.1", 1)])
        strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(in_use))
        # Two equally idle servers: consecutive calls pick each in turn.
        for expected in ("0.0.0.0", "0.0.0.1"):
            picked = strategy.select_reader(["0.0.0.0", "0.0.0.1", "1.1.1.1"])
            self.assertEqual(picked, expected)

    def test_empty_reader_selection(self):
        in_use = OrderedDict([])
        strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(in_use))
        self.assertIsNone(strategy.select_reader([]))

    def test_not_in_pool_reader_selection(self):
        in_use = OrderedDict([("1.1.1.1", 1), ("2.2.2.2", 2)])
        strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(in_use))
        picked = strategy.select_reader(["2.2.2.2", "3.3.3.3"])
        self.assertEqual(picked, "3.3.3.3")

    def test_partially_in_pool_reader_selection(self):
        in_use = OrderedDict([("1.1.1.1", 1), ("2.2.2.2", 0)])
        strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(in_use))
        for expected in ("2.2.2.2", "3.3.3.3"):
            picked = strategy.select_reader(["2.2.2.2", "3.3.3.3"])
            self.assertEqual(picked, expected)

    def test_simple_writer_selection(self):
        in_use = OrderedDict([("0.0.0.0", 2), ("1.1.1.1", 1), ("2.2.2.2", 0)])
        strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(in_use))
        picked = strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"])
        self.assertEqual(picked, "2.2.2.2")

    def test_writer_selection_with_clash(self):
        in_use = OrderedDict([("0.0.0.0", 0), ("0.0.0.1", 0), ("1.1.1.1", 1)])
        strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(in_use))
        # Two equally idle servers: consecutive calls pick each in turn.
        for expected in ("0.0.0.0", "0.0.0.1"):
            picked = strategy.select_writer(["0.0.0.0", "0.0.0.1", "1.1.1.1"])
            self.assertEqual(picked, expected)

    def test_empty_writer_selection(self):
        in_use = OrderedDict([])
        strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(in_use))
        self.assertIsNone(strategy.select_writer([]))

    def test_not_in_pool_writer_selection(self):
        in_use = OrderedDict([("1.1.1.1", 1), ("2.2.2.2", 2)])
        strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(in_use))
        picked = strategy.select_writer(["2.2.2.2", "3.3.3.3"])
        self.assertEqual(picked, "3.3.3.3")

    def test_partially_in_pool_writer_selection(self):
        in_use = OrderedDict([("1.1.1.1", 1), ("2.2.2.2", 0)])
        strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(in_use))
        for expected in ("2.2.2.2", "3.3.3.3"):
            picked = strategy.select_writer(["2.2.2.2", "3.3.3.3"])
            self.assertEqual(picked, expected)