diff --git a/neo4j/v1/bolt.py b/neo4j/v1/bolt.py index 81a50ede8..f237ed564 100644 --- a/neo4j/v1/bolt.py +++ b/neo4j/v1/bolt.py @@ -83,9 +83,10 @@ class BufferingSocket(object): - def __init__(self, socket): - self.address = socket.getpeername() - self.socket = socket + def __init__(self, connection): + self.connection = connection + self.socket = connection.socket + self.address = self.socket.getpeername() self.buffer = bytearray() def fill(self): @@ -96,6 +97,10 @@ def fill(self): self.buffer[len(self.buffer):] = received else: if ready_to_read is not None: + # If this connection fails, remove this address from the + # connection pool to which this connection belongs. + if self.connection.pool: + self.connection.pool.remove(self.address) raise ServiceUnavailable("Failed to read from connection %r" % (self.address,)) def read_message(self): @@ -211,9 +216,12 @@ class Connection(object): .. note:: logs at INFO level """ + #: The pool of which this connection is a member + pool = None + def __init__(self, sock, **config): self.socket = sock - self.buffering_socket = BufferingSocket(sock) + self.buffering_socket = BufferingSocket(self) self.address = sock.getpeername() self.channel = ChunkChannel(sock) self.packer = Packer(self.channel) @@ -411,6 +419,7 @@ def acquire(self, address): connection.in_use = True return connection connection = self.connector(address) + connection.pool = self connection.in_use = True connections.append(connection) return connection diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 5b24297c6..a408fdfe1 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -263,22 +263,40 @@ def refresh_routing_table(self): def acquire_for_read(self): """ Acquire a connection to a read server. """ - self.refresh_routing_table() - return self.acquire(next(self.routing_table.readers)) + while True: + address = None + while address is None: + self.refresh_routing_table() + address = next(self.routing_table.readers) + try: + connection = self.acquire(address) + except ServiceUnavailable: + self.remove(address) + else: + return connection def acquire_for_write(self): """ Acquire a connection to a write server. """ - self.refresh_routing_table() - return self.acquire(next(self.routing_table.writers)) + while True: + address = None + while address is None: + self.refresh_routing_table() + address = next(self.routing_table.writers) + try: + connection = self.acquire(address) + except ServiceUnavailable: + self.remove(address) + else: + return connection def remove(self, address): """ Remove an address from the connection pool, if present, closing all connections to that address. Also remove from the routing table. """ - super(RoutingConnectionPool, self).remove(address) # We use `discard` instead of `remove` here since the former # will not fail if the address has already been removed. self.routing_table.routers.discard(address) self.routing_table.readers.discard(address) self.routing_table.writers.discard(address) + super(RoutingConnectionPool, self).remove(address) diff --git a/test/resources/fail_on_init.script b/test/resources/fail_on_init.script new file mode 100644 index 000000000..0c8341421 --- /dev/null +++ b/test/resources/fail_on_init.script @@ -0,0 +1,4 @@ +!: AUTO INIT +!: AUTO RESET + +S: diff --git a/test/resources/router_with_multiple_writers.script b/test/resources/router_with_multiple_writers.script new file mode 100644 index 000000000..e3ae63f78 --- /dev/null +++ b/test/resources/router_with_multiple_writers.script @@ -0,0 +1,8 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "CALL dbms.cluster.routing.getServers" {} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [300, [{"role":"ROUTE","addresses":["127.0.0.1:9001","127.0.0.1:9002","127.0.0.1:9003"]},{"role":"READ","addresses":["127.0.0.1:9004","127.0.0.1:9005"]},{"role":"WRITE","addresses":["127.0.0.1:9006","127.0.0.1:9007"]}]] + SUCCESS {} diff --git a/test/test_routing.py b/test/test_routing.py index 912d4aded..8a30aa3e8 100644 --- a/test/test_routing.py +++ b/test/test_routing.py @@ -577,6 +577,17 @@ def test_connected_to_reader(self): connection = pool.acquire_for_read() assert connection.address in pool.routing_table.readers + def test_should_retry_if_first_reader_fails(self): + with StubCluster({9001: "router.script", + 9004: "fail_on_init.script", + 9005: "empty.script"}): + address = ("127.0.0.1", 9001) + with RoutingConnectionPool(connector, address) as pool: + assert not pool.routing_table.is_fresh() + _ = pool.acquire_for_read() + assert ("127.0.0.1", 9004) not in pool.routing_table.readers + assert ("127.0.0.1", 9005) in pool.routing_table.readers + class RoutingConnectionPoolAcquireForWriteTestCase(ServerTestCase): @@ -596,6 +607,17 @@ def test_connected_to_writer(self): connection = pool.acquire_for_write() assert connection.address in pool.routing_table.writers + def test_should_retry_if_first_writer_fails(self): + with StubCluster({9001: "router_with_multiple_writers.script", + 9006: "fail_on_init.script", + 9007: "empty.script"}): + address = ("127.0.0.1", 9001) + with RoutingConnectionPool(connector, address) as pool: + assert not pool.routing_table.is_fresh() + _ = pool.acquire_for_write() + assert ("127.0.0.1", 9006) not in pool.routing_table.writers + assert ("127.0.0.1", 9007) in pool.routing_table.writers + class RoutingConnectionPoolRemoveTestCase(ServerTestCase): diff --git a/test/util.py b/test/util.py index f4fb09ca5..ff12f5c94 100644 --- a/test/util.py +++ b/test/util.py @@ -85,7 +85,6 @@ class ServerTestCase(TestCase): known_hosts = KNOWN_HOSTS known_hosts_backup = known_hosts + ".backup" - servers = [] def setUp(self): if isfile(self.known_hosts):