From c660463a46865b8bcc765e3613637e75aa417f55 Mon Sep 17 00:00:00 2001 From: lutovich Date: Thu, 13 Jul 2017 11:58:50 +0200 Subject: [PATCH 1/7] Initial impl of least connected --- neo4j/bolt/connection.py | 14 +++++ neo4j/v1/routing.py | 113 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 122 insertions(+), 5 deletions(-) diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index 9a505f6e5..de4acbc5c 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -418,6 +418,20 @@ def release(self, connection): with self.lock: connection.in_use = False + def in_use_connection_count(self, address): + try: + connections = self.connections[address] + except KeyError: + return 0 + else: + in_use_count = 0 + + for connection in list(connections): + if connection.in_use: + in_use_count += 1 + + return in_use_count + def remove(self, address): """ Remove an address from the connection pool, if present, closing all connections to that address. diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 7ecf8c1a7..a625cbba5 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -19,6 +19,7 @@ # limitations under the License. +from sys import maxsize from threading import Lock from time import clock @@ -26,11 +27,16 @@ from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect from neo4j.compat.collections import MutableSet, OrderedDict from neo4j.exceptions import CypherError +from neo4j.util import ServerVersion from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters from neo4j.v1.exceptions import SessionExpired from neo4j.v1.security import SecurityPlan from neo4j.v1.session import BoltSession -from neo4j.util import ServerVersion + + +LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0 +LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1 +LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED class RoundRobinSet(MutableSet): @@ -52,7 +58,7 @@ def __next__(self): self._current = 0 else: self._current = (self._current + 1) % len(self._elements) - current = list(self._elements.keys())[self._current] + current = self.get(self._current) return current def __iter__(self): @@ -90,6 +96,9 @@ def replace(self, elements=()): e.clear() e.update(OrderedDict.fromkeys(elements)) + def get(self, index): + return list(self._elements.keys())[index] + class RoutingTable(object): @@ -168,17 +177,109 @@ def __run__(self, ignored, routing_context): return self._run(fix_statement(statement), fix_parameters(parameters)) +class LoadBalancingStrategy(object): + + @classmethod + def build(cls, connection_pool, **config): + load_balancing_strategy = config.get("load_balancing_strategy", LOAD_BALANCING_STRATEGY_DEFAULT) + if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED: + return LeastConnectedLoadBalancingStrategy(connection_pool) + elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN: + return RoundRobinLoadBalancingStrategy() + else: + raise ValueError("Unknown load balancing strategy '%s'" % load_balancing_strategy) + pass + + def select_reader(self, known_readers): + raise NotImplementedError() + + def select_writer(self, known_writers): + raise NotImplementedError() + + +class RoundRobinLoadBalancingStrategy(LoadBalancingStrategy): + + _readers_offset = 0 + _writers_offset = 0 + + def select_reader(self, known_readers): + address = self.select(self._readers_offset, known_readers) + self._readers_offset += 1 + return address + + def select_writer(self, known_writers): + address = self.select(self._writers_offset, known_writers) + self._writers_offset += 1 + return address + + def select(self, offset, addresses): + length = len(addresses) + if length == 0: + return None + else: + index = offset % length + return addresses.get(index) + + +class LeastConnectedLoadBalancingStrategy(LoadBalancingStrategy): + + def __init__(self, connection_pool): + self._readers_offset = 0 + self._writers_offset = 0 + self._connection_pool = connection_pool + + def select_reader(self, known_readers): + address = self.select(self._readers_offset, known_readers) + self._readers_offset += 1 + return address + + def select_writer(self, known_writers): + address = self.select(self._writers_offset, known_writers) + self._writers_offset += 1 + return address + + def select(self, offset, addresses): + length = len(addresses) + if length == 0: + return None + else: + start_index = offset % length + index = start_index + + least_connected_address = None + least_in_use_connections = maxsize + + while True: + address = addresses.get(index) + in_use_connections = self._connection_pool.in_use_connection_count(address) + + if in_use_connections < least_in_use_connections: + least_connected_address = address + least_in_use_connections = in_use_connections + + if index == length - 1: + index = 0 + else: + index += 1 + + if index == start_index: + break + + return least_connected_address + + class RoutingConnectionPool(ConnectionPool): """ Connection pool with routing table. """ - def __init__(self, connector, initial_address, routing_context, *routers): + def __init__(self, connector, initial_address, routing_context, *routers, **config): super(RoutingConnectionPool, self).__init__(connector) self.initial_address = initial_address self.routing_context = routing_context self.routing_table = RoutingTable(routers) self.missing_writer = False self.refresh_lock = Lock() + self.load_balancing_strategy = LoadBalancingStrategy.build(self, **config) def fetch_routing_info(self, address): """ Fetch raw routing info from a given router address. @@ -304,14 +405,16 @@ def acquire(self, access_mode=None): access_mode = WRITE_ACCESS if access_mode == READ_ACCESS: server_list = self.routing_table.readers + server_selector = self.load_balancing_strategy.select_reader elif access_mode == WRITE_ACCESS: server_list = self.routing_table.writers + server_selector = self.load_balancing_strategy.select_writer else: raise ValueError("Unsupported access mode {}".format(access_mode)) self.ensure_routing_table_is_fresh(access_mode) while True: - address = next(server_list) + address = server_selector(server_list) if address is None: break try: @@ -354,7 +457,7 @@ def __init__(self, uri, **config): def connector(a): return connect(a, security_plan.ssl_context, **config) - pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address)) + pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address), **config) try: pool.update_routing_table() except: From f1c1c83b15554f82e433477ae1c8ee0654a95405 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 14 Jul 2017 11:35:40 +0100 Subject: [PATCH 2/7] Test in use count --- neo4j/bolt/connection.py | 11 ++++------- test/integration/test_connection.py | 8 ++++++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index de4acbc5c..d35a603ec 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -419,18 +419,15 @@ def release(self, connection): connection.in_use = False def in_use_connection_count(self, address): + """ Count the number of connections currently in use to a given + address. + """ try: connections = self.connections[address] except KeyError: return 0 else: - in_use_count = 0 - - for connection in list(connections): - if connection.in_use: - in_use_count += 1 - - return in_use_count + return sum(1 if connection.in_use else 0 for connection in connections) def remove(self, address): """ Remove an address from the connection pool, if present, closing diff --git a/test/integration/test_connection.py b/test/integration/test_connection.py index 897b6f47e..1a28c2ff9 100644 --- a/test/integration/test_connection.py +++ b/test/integration/test_connection.py @@ -108,3 +108,11 @@ def test_cannot_acquire_after_close(self): pool.close() with self.assertRaises(ServiceUnavailable): _ = pool.acquire_direct("X") + + def test_in_use_count(self): + address = ("127.0.0.1", 7687) + self.assertEqual(self.pool.in_use_connection_count(address), 0) + connection = self.pool.acquire_direct(address) + self.assertEqual(self.pool.in_use_connection_count(address), 1) + self.pool.release(connection) + self.assertEqual(self.pool.in_use_connection_count(address), 0) From ec5aceb3d3f1157313c616d56ba71f018af4b46c Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 14 Jul 2017 12:00:27 +0100 Subject: [PATCH 3/7] RoundRobinSet->OrderedSet plus tests --- neo4j/v1/routing.py | 29 +++-------- test/unit/test_routing.py | 103 ++++++++++++++++++-------------------- 2 files changed, 56 insertions(+), 76 deletions(-) diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index a625cbba5..aa84d9a28 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -39,7 +39,7 @@ LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED -class RoundRobinSet(MutableSet): +class OrderedSet(MutableSet): def __init__(self, elements=()): self._elements = OrderedDict.fromkeys(elements) @@ -51,22 +51,15 @@ def __repr__(self): def __contains__(self, element): return element in self._elements - def __next__(self): - current = None - if self._elements: - if self._current is None: - self._current = 0 - else: - self._current = (self._current + 1) % len(self._elements) - current = self.get(self._current) - return current - def __iter__(self): return iter(self._elements) def __len__(self): return len(self._elements) + def __getitem__(self, index): + return list(self._elements.keys())[index] + def add(self, element): self._elements[element] = None @@ -79,9 +72,6 @@ def discard(self, element): except KeyError: pass - def next(self): - return self.__next__() - def remove(self, element): try: del self._elements[element] @@ -96,9 +86,6 @@ def replace(self, elements=()): e.clear() e.update(OrderedDict.fromkeys(elements)) - def get(self, index): - return list(self._elements.keys())[index] - class RoutingTable(object): @@ -135,9 +122,9 @@ def parse_routing_info(cls, records): return cls(routers, readers, writers, ttl) def __init__(self, routers=(), readers=(), writers=(), ttl=0): - self.routers = RoundRobinSet(routers) - self.readers = RoundRobinSet(readers) - self.writers = RoundRobinSet(writers) + self.routers = OrderedSet(routers) + self.readers = OrderedSet(readers) + self.writers = OrderedSet(writers) self.last_updated_time = self.timer() self.ttl = ttl @@ -250,7 +237,7 @@ def select(self, offset, addresses): least_in_use_connections = maxsize while True: - address = addresses.get(index) + address = addresses[index] in_use_connections = self._connection_pool.in_use_connection_count(address) if in_use_connections < least_in_use_connections: diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index 2b9fce5ce..24fdc6505 100644 --- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -22,7 +22,7 @@ from neo4j.bolt import ProtocolError from neo4j.bolt.connection import connect -from neo4j.v1.routing import RoundRobinSet, RoutingTable, RoutingConnectionPool +from neo4j.v1.routing import OrderedSet, RoutingTable, RoutingConnectionPool from neo4j.v1.security import basic_auth from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS @@ -58,87 +58,80 @@ def connector(address): class RoundRobinSetTestCase(TestCase): def test_should_repr_as_set(self): - rrs = RoundRobinSet([1, 2, 3]) - assert repr(rrs) == "{1, 2, 3}" + s = OrderedSet([1, 2, 3]) + assert repr(s) == "{1, 2, 3}" def test_should_contain_element(self): - rrs = RoundRobinSet([1, 2, 3]) - assert 2 in rrs + s = OrderedSet([1, 2, 3]) + assert 2 in s def test_should_not_contain_non_element(self): - rrs = RoundRobinSet([1, 2, 3]) - assert 4 not in rrs - - def test_should_be_able_to_get_next_if_empty(self): - rrs = RoundRobinSet([]) - assert next(rrs) is None - - def test_should_be_able_to_get_next_repeatedly(self): - rrs = RoundRobinSet([1, 2, 3]) - assert next(rrs) == 1 - assert next(rrs) == 2 - assert next(rrs) == 3 - assert next(rrs) == 1 - - def test_should_be_able_to_get_next_repeatedly_via_old_method(self): - rrs = RoundRobinSet([1, 2, 3]) - assert rrs.next() == 1 - assert rrs.next() == 2 - assert rrs.next() == 3 - assert rrs.next() == 1 + s = OrderedSet([1, 2, 3]) + assert 4 not in s + + def test_should_be_able_to_get_item_if_empty(self): + s = OrderedSet([]) + with self.assertRaises(IndexError): + _ = s[0] + + def test_should_be_able_to_get_items_by_index(self): + s = OrderedSet([1, 2, 3]) + self.assertEqual(s[0], 1) + self.assertEqual(s[1], 2) + self.assertEqual(s[2], 3) def test_should_be_iterable(self): - rrs = RoundRobinSet([1, 2, 3]) - assert list(iter(rrs)) == [1, 2, 3] + s = OrderedSet([1, 2, 3]) + assert list(iter(s)) == [1, 2, 3] def test_should_have_length(self): - rrs = RoundRobinSet([1, 2, 3]) - assert len(rrs) == 3 + s = OrderedSet([1, 2, 3]) + assert len(s) == 3 def test_should_be_able_to_add_new(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.add(4) - assert list(rrs) == [1, 2, 3, 4] + s = OrderedSet([1, 2, 3]) + s.add(4) + assert list(s) == [1, 2, 3, 4] def test_should_be_able_to_add_existing(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.add(2) - assert list(rrs) == [1, 2, 3] + s = OrderedSet([1, 2, 3]) + s.add(2) + assert list(s) == [1, 2, 3] def test_should_be_able_to_clear(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.clear() - assert list(rrs) == [] + s = OrderedSet([1, 2, 3]) + s.clear() + assert list(s) == [] def test_should_be_able_to_discard_existing(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.discard(2) - assert list(rrs) == [1, 3] + s = OrderedSet([1, 2, 3]) + s.discard(2) + assert list(s) == [1, 3] def test_should_be_able_to_discard_non_existing(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.discard(4) - assert list(rrs) == [1, 2, 3] + s = OrderedSet([1, 2, 3]) + s.discard(4) + assert list(s) == [1, 2, 3] def test_should_be_able_to_remove_existing(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.remove(2) - assert list(rrs) == [1, 3] + s = OrderedSet([1, 2, 3]) + s.remove(2) + assert list(s) == [1, 3] def test_should_not_be_able_to_remove_non_existing(self): - rrs = RoundRobinSet([1, 2, 3]) + s = OrderedSet([1, 2, 3]) with self.assertRaises(ValueError): - rrs.remove(4) + s.remove(4) def test_should_be_able_to_update(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.update([3, 4, 5]) - assert list(rrs) == [1, 2, 3, 4, 5] + s = OrderedSet([1, 2, 3]) + s.update([3, 4, 5]) + assert list(s) == [1, 2, 3, 4, 5] def test_should_be_able_to_replace(self): - rrs = RoundRobinSet([1, 2, 3]) - rrs.replace([3, 4, 5]) - assert list(rrs) == [3, 4, 5] + s = OrderedSet([1, 2, 3]) + s.replace([3, 4, 5]) + assert list(s) == [3, 4, 5] class RoutingTableConstructionTestCase(TestCase): From 7f5ab21cb8e50e6268498ab61d6ab6fb7c4ca274 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 14 Jul 2017 13:34:21 +0100 Subject: [PATCH 4/7] Added .benchmarks to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 370aa4039..930d9fa69 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ htmlcov .coverage .test .tox +.benchmarks .cache docs/build From 47bd1361145c46a711143675b3b9a9fbb8db04dd Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 14 Jul 2017 15:47:25 +0100 Subject: [PATCH 5/7] Tests for routing strategies --- neo4j/v1/routing.py | 65 ++++++++++---------- test/unit/test_routing.py | 124 +++++++++++++++++++++++++++++++++++--- 2 files changed, 146 insertions(+), 43 deletions(-) diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index aa84d9a28..7b37aeff3 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -17,8 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +from abc import abstractmethod from sys import maxsize from threading import Lock from time import clock @@ -175,11 +174,12 @@ def build(cls, connection_pool, **config): return RoundRobinLoadBalancingStrategy() else: raise ValueError("Unknown load balancing strategy '%s'" % load_balancing_strategy) - pass + @abstractmethod def select_reader(self, known_readers): raise NotImplementedError() + @abstractmethod def select_writer(self, known_writers): raise NotImplementedError() @@ -190,22 +190,20 @@ class RoundRobinLoadBalancingStrategy(LoadBalancingStrategy): _writers_offset = 0 def select_reader(self, known_readers): - address = self.select(self._readers_offset, known_readers) + address = self._select(self._readers_offset, known_readers) self._readers_offset += 1 return address def select_writer(self, known_writers): - address = self.select(self._writers_offset, known_writers) + address = self._select(self._writers_offset, known_writers) self._writers_offset += 1 return address - def select(self, offset, addresses): - length = len(addresses) - if length == 0: + @classmethod + def _select(cls, offset, addresses): + if not addresses: return None - else: - index = offset % length - return addresses.get(index) + return addresses[offset % len(addresses)] class LeastConnectedLoadBalancingStrategy(LoadBalancingStrategy): @@ -216,43 +214,42 @@ def __init__(self, connection_pool): self._connection_pool = connection_pool def select_reader(self, known_readers): - address = self.select(self._readers_offset, known_readers) + address = self._select(self._readers_offset, known_readers) self._readers_offset += 1 return address def select_writer(self, known_writers): - address = self.select(self._writers_offset, known_writers) + address = self._select(self._writers_offset, known_writers) self._writers_offset += 1 return address - def select(self, offset, addresses): - length = len(addresses) - if length == 0: + def _select(self, offset, addresses): + if not addresses: return None - else: - start_index = offset % length - index = start_index + num_addresses = len(addresses) + start_index = offset % num_addresses + index = start_index - least_connected_address = None - least_in_use_connections = maxsize + least_connected_address = None + least_in_use_connections = maxsize - while True: - address = addresses[index] - in_use_connections = self._connection_pool.in_use_connection_count(address) + while True: + address = addresses[index] + in_use_connections = self._connection_pool.in_use_connection_count(address) - if in_use_connections < least_in_use_connections: - least_connected_address = address - least_in_use_connections = in_use_connections + if in_use_connections < least_in_use_connections: + least_connected_address = address + least_in_use_connections = in_use_connections - if index == length - 1: - index = 0 - else: - index += 1 + if index == num_addresses - 1: + index = 0 + else: + index += 1 - if index == start_index: - break + if index == start_index: + break - return least_connected_address + return least_connected_address class RoutingConnectionPool(ConnectionPool): diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index 24fdc6505..92ca5507d 100644 --- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -17,14 +17,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from collections import OrderedDict from unittest import TestCase from neo4j.bolt import ProtocolError from neo4j.bolt.connection import connect -from neo4j.v1.routing import OrderedSet, RoutingTable, RoutingConnectionPool +from neo4j.v1.routing import OrderedSet, RoutingTable, RoutingConnectionPool, LeastConnectedLoadBalancingStrategy, \ + RoundRobinLoadBalancingStrategy from neo4j.v1.security import basic_auth -from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS +from neo4j.v1.api import READ_ACCESS, WRITE_ACCESS VALID_ROUTING_RECORD = { @@ -56,7 +57,6 @@ def connector(address): class RoundRobinSetTestCase(TestCase): - def test_should_repr_as_set(self): s = OrderedSet([1, 2, 3]) assert repr(s) == "{1, 2, 3}" @@ -135,7 +135,6 @@ def test_should_be_able_to_replace(self): class RoutingTableConstructionTestCase(TestCase): - def test_should_be_initially_stale(self): table = RoutingTable() assert not table.is_fresh(READ_ACCESS) @@ -143,7 +142,6 @@ def test_should_be_initially_stale(self): class RoutingTableParseRoutingInfoTestCase(TestCase): - def test_should_return_routing_table_on_valid_record(self): table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} @@ -172,7 +170,6 @@ def test_should_fail_on_multiple_records(self): class RoutingTableFreshnessTestCase(TestCase): - def test_should_be_fresh_after_update(self): table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) assert table.is_fresh(READ_ACCESS) @@ -198,7 +195,6 @@ def test_should_become_stale_if_no_writers(self): class RoutingTableUpdateTestCase(TestCase): - def setUp(self): self.table = RoutingTable( [("192.168.1.1", 7687), ("192.168.1.2", 7687)], [("192.168.1.3", 7687)], [], 0) @@ -224,9 +220,119 @@ def test_update_should_replace_ttl(self): class RoutingConnectionPoolConstructionTestCase(TestCase): - def test_should_populate_initial_router(self): initial_router = ("127.0.0.1", 9001) router = ("127.0.0.1", 9002) with RoutingConnectionPool(connector, initial_router, {}, router) as pool: assert pool.routing_table.routers == {("127.0.0.1", 9002)} + + +class FakeConnectionPool(object): + + def __init__(self, addresses): + self._addresses = addresses + + def in_use_connection_count(self, address): + return self._addresses.get(address, 0) + + +class RoundRobinLoadBalancingStrategyTestCase(TestCase): + + def test_simple_reader_selection(self): + strategy = RoundRobinLoadBalancingStrategy() + self.assertEqual(strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "0.0.0.0") + self.assertEqual(strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "1.1.1.1") + self.assertEqual(strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "2.2.2.2") + self.assertEqual(strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "0.0.0.0") + + def test_empty_reader_selection(self): + strategy = RoundRobinLoadBalancingStrategy() + self.assertIsNone(strategy.select_reader([])) + + def test_simple_writer_selection(self): + strategy = RoundRobinLoadBalancingStrategy() + self.assertEqual(strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "0.0.0.0") + self.assertEqual(strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "1.1.1.1") + self.assertEqual(strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "2.2.2.2") + self.assertEqual(strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "0.0.0.0") + + def test_empty_writer_selection(self): + strategy = RoundRobinLoadBalancingStrategy() + self.assertIsNone(strategy.select_writer([])) + + +class LeastConnectedLoadBalancingStrategyTestCase(TestCase): + + def test_simple_reader_selection(self): + strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([ + ("0.0.0.0", 2), + ("1.1.1.1", 1), + ("2.2.2.2", 0), + ]))) + self.assertEqual(strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "2.2.2.2") + + def test_reader_selection_with_clash(self): + strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([ + ("0.0.0.0", 0), + ("0.0.0.1", 0), + ("1.1.1.1", 1), + ]))) + self.assertEqual(strategy.select_reader(["0.0.0.0", "0.0.0.1", "1.1.1.1"]), "0.0.0.0") + self.assertEqual(strategy.select_reader(["0.0.0.0", "0.0.0.1", "1.1.1.1"]), "0.0.0.1") + + def test_empty_reader_selection(self): + strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([ + ]))) + self.assertIsNone(strategy.select_reader([])) + + def test_not_in_pool_reader_selection(self): + strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([ + ("1.1.1.1", 1), + ("2.2.2.2", 2), + ]))) + self.assertEqual(strategy.select_reader(["2.2.2.2", "3.3.3.3"]), "3.3.3.3") + + def test_partially_in_pool_reader_selection(self): + strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([ + ("1.1.1.1", 1), + ("2.2.2.2", 0), + ]))) + self.assertEqual(strategy.select_reader(["2.2.2.2", "3.3.3.3"]), "2.2.2.2") + self.assertEqual(strategy.select_reader(["2.2.2.2", "3.3.3.3"]), "3.3.3.3") + + def test_simple_writer_selection(self): + strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([ + ("0.0.0.0", 2), + ("1.1.1.1", 1), + ("2.2.2.2", 0), + ]))) + self.assertEqual(strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "2.2.2.2") + + def test_writer_selection_with_clash(self): + strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([ + ("0.0.0.0", 0), + ("0.0.0.1", 0), + ("1.1.1.1", 1), + ]))) + self.assertEqual(strategy.select_writer(["0.0.0.0", "0.0.0.1", "1.1.1.1"]), "0.0.0.0") + self.assertEqual(strategy.select_writer(["0.0.0.0", "0.0.0.1", "1.1.1.1"]), "0.0.0.1") + + def test_empty_writer_selection(self): + strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([ + ]))) + self.assertIsNone(strategy.select_writer([])) + + def test_not_in_pool_writer_selection(self): + strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([ + ("1.1.1.1", 1), + ("2.2.2.2", 2), + ]))) + self.assertEqual(strategy.select_writer(["2.2.2.2", "3.3.3.3"]), "3.3.3.3") + + def test_partially_in_pool_writer_selection(self): + strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([ + ("1.1.1.1", 1), + ("2.2.2.2", 0), + ]))) + self.assertEqual(strategy.select_writer(["2.2.2.2", "3.3.3.3"]), "2.2.2.2") + self.assertEqual(strategy.select_writer(["2.2.2.2", "3.3.3.3"]), "3.3.3.3") From 1f4dd6633ed0b29707f4ec63fa5e96839b2acdfc Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 14 Jul 2017 15:54:24 +0100 Subject: [PATCH 6/7] Leaset connected selection refactoring --- neo4j/v1/routing.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 7b37aeff3..7d3ffbda0 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -235,21 +235,16 @@ def _select(self, offset, addresses): while True: address = addresses[index] + index = (index + 1) % num_addresses + in_use_connections = self._connection_pool.in_use_connection_count(address) if in_use_connections < least_in_use_connections: least_connected_address = address least_in_use_connections = in_use_connections - if index == num_addresses - 1: - index = 0 - else: - index += 1 - if index == start_index: - break - - return least_connected_address + return least_connected_address class RoutingConnectionPool(ConnectionPool): From bb7fc5ea5b637e11c27d986cd3037d70f13b8566 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 14 Jul 2017 16:06:29 +0100 Subject: [PATCH 7/7] Load balancing strategy config --- test/stub/test_routingdriver.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/test/stub/test_routingdriver.py b/test/stub/test_routingdriver.py index debe7050a..0edefe666 100644 --- a/test/stub/test_routingdriver.py +++ b/test/stub/test_routingdriver.py @@ -19,7 +19,9 @@ # limitations under the License. -from neo4j.v1 import GraphDatabase, RoutingDriver, READ_ACCESS, WRITE_ACCESS, SessionExpired +from neo4j.v1 import GraphDatabase, READ_ACCESS, WRITE_ACCESS, SessionExpired, \ + RoutingDriver, RoutingConnectionPool, LeastConnectedLoadBalancingStrategy, LOAD_BALANCING_STRATEGY_ROUND_ROBIN, \ + RoundRobinLoadBalancingStrategy from neo4j.bolt import ProtocolError, ServiceUnavailable from test.stub.tools import StubTestCase, StubCluster @@ -215,3 +217,27 @@ def test_should_error_when_missing_reader(self): uri = "bolt+routing://127.0.0.1:9001" with self.assertRaises(ProtocolError): GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) + + def test_default_load_balancing_strategy_is_least_connected(self): + with StubCluster({9001: "router.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + self.assertIsInstance(driver, RoutingDriver) + self.assertIsInstance(driver._pool, RoutingConnectionPool) + self.assertIsInstance(driver._pool.load_balancing_strategy, LeastConnectedLoadBalancingStrategy) + + def test_can_select_round_robin_load_balancing_strategy(self): + with StubCluster({9001: "router.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, + load_balancing_strategy=LOAD_BALANCING_STRATEGY_ROUND_ROBIN) as driver: + self.assertIsInstance(driver, RoutingDriver) + self.assertIsInstance(driver._pool, RoutingConnectionPool) + self.assertIsInstance(driver._pool.load_balancing_strategy, RoundRobinLoadBalancingStrategy) + + def test_no_other_load_balancing_strategies_are_available(self): + with StubCluster({9001: "router.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with self.assertRaises(ValueError): + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, load_balancing_strategy=-1): + pass