diff --git a/docs/source/driver.rst b/docs/source/driver.rst index 10743abe4..4df88ea8d 100644 --- a/docs/source/driver.rst +++ b/docs/source/driver.rst @@ -144,6 +144,25 @@ The maximum time to allow for retries to be attempted when using transaction fun After this time, no more retries will be attempted. This setting does not terminate running queries. +``resolver`` +------------ + +A custom resolver function to resolve host and port values ahead of DNS resolution. +This function is called with a 2-tuple of (host, port) and should return an iterable of tuples as would be returned from ``getaddrinfo``. +If no custom resolver function is supplied, the internal resolver moves straight to regular DNS resolution. + +For example:: + + def my_resolver(socket_address): + if socket_address == ("foo", 9999): + yield "::1", 7687 + yield "127.0.0.1", 7687 + else: + from socket import gaierror + raise gaierror("Unexpected socket address %r" % socket_address) + + driver = GraphDatabase.driver("bolt+routing://foo:9999", auth=("neo4j", "password"), resolver=my_resolver) + Object Lifetime diff --git a/neo4j/addressing.py b/neo4j/addressing.py index 5ada9b8e6..88af80dd8 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -21,14 +21,9 @@ from collections import namedtuple from socket import getaddrinfo, gaierror, SOCK_STREAM, IPPROTO_TCP -from neo4j.compat import urlparse +from neo4j.compat import urlparse, parse_qs from neo4j.exceptions import AddressError -try: - from urllib.parse import parse_qs -except ImportError: - from urlparse import parse_qs - VALID_IPv4_SEGMENTS = [str(i).encode("latin1") for i in range(0x100)] VALID_IPv6_SEGMENT_CHARS = b"0123456789abcdef" @@ -102,17 +97,46 @@ def parse_routing_context(cls, uri): return context -def resolve(socket_address): - try: - info = getaddrinfo(socket_address[0], socket_address[1], 0, SOCK_STREAM, IPPROTO_TCP) - except gaierror: - raise AddressError("Cannot resolve address {!r}".format(socket_address[0])) - else: - addresses = [] - for _, _, _, _, address in info: - if len(address) == 4 and address[3] != 0: - # skip any IPv6 addresses with a non-zero scope id - # as these appear to cause problems on some platforms - continue - addresses.append(address) - return addresses +class Resolver(object): + """ A Resolver instance stores a list of addresses, each in a tuple, and + provides methods to perform resolution on these, thereby replacing them + with the resolved values. + """ + + def __init__(self, custom_resolver=None): + self.addresses = [] + self.custom_resolver = custom_resolver + + def custom_resolve(self): + """ If a custom resolver is defined, perform custom resolution on + the contained addresses. + + :return: + """ + if not callable(self.custom_resolver): + return + new_addresses = [] + for address in self.addresses: + for new_address in self.custom_resolver(address): + new_addresses.append(new_address) + self.addresses = new_addresses + + def dns_resolve(self): + """ Perform DNS resolution on the contained addresses. + + :return: + """ + new_addresses = [] + for address in self.addresses: + try: + info = getaddrinfo(address[0], address[1], 0, SOCK_STREAM, IPPROTO_TCP) + except gaierror: + raise AddressError("Cannot resolve address {!r}".format(address)) + else: + for _, _, _, _, address in info: + if len(address) == 4 and address[3] != 0: + # skip any IPv6 addresses with a non-zero scope id + # as these appear to cause problems on some platforms + continue + new_addresses.append(address) + self.addresses = new_addresses diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index 3c4b77d51..21ce86521 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -33,7 +33,7 @@ from struct import pack as struct_pack, unpack as struct_unpack from threading import RLock, Condition -from neo4j.addressing import SocketAddress, resolve +from neo4j.addressing import SocketAddress, Resolver from neo4j.bolt.cert import KNOWN_HOSTS from neo4j.bolt.response import InitResponse, AckFailureResponse, ResetResponse from neo4j.compat.ssl import SSL_AVAILABLE, HAS_SNI, SSLError @@ -685,12 +685,16 @@ def connect(address, ssl_context=None, error_handler=None, **config): a protocol version can be agreed. """ + last_error = None # Establish a connection to the host and port specified # Catches refused connections see: # https://docs.python.org/2/library/errno.html log_debug("~~ [RESOLVE] %s", address) - last_error = None - for resolved_address in resolve(address): + resolver = Resolver(custom_resolver=config.get("resolver")) + resolver.addresses.append(address) + resolver.custom_resolve() + resolver.dns_resolve() + for resolved_address in resolver.addresses: log_debug("~~ [RESOLVED] %s -> %s", address, resolved_address) try: s = _connect(resolved_address, **config) diff --git a/neo4j/compat/__init__.py b/neo4j/compat/__init__.py index 2a49a2d03..c175c49aa 100644 --- a/neo4j/compat/__init__.py +++ b/neo4j/compat/__init__.py @@ -122,6 +122,6 @@ def perf_counter(): # The location of urlparse varies between Python 2 and 3 try: - from urllib.parse import urlparse + from urllib.parse import urlparse, parse_qs except ImportError: - from urlparse import urlparse + from urlparse import urlparse, parse_qs diff --git a/test/integration/test_driver.py b/test/integration/test_driver.py index f921ea57e..3c29eeab4 100644 --- a/test/integration/test_driver.py +++ b/test/integration/test_driver.py @@ -18,7 +18,8 @@ # limitations under the License. -from neo4j.v1 import GraphDatabase, ServiceUnavailable +from neo4j.bolt import DEFAULT_PORT +from neo4j.v1 import GraphDatabase, Driver, ServiceUnavailable from test.integration.tools import IntegrationTestCase @@ -43,3 +44,15 @@ def test_fail_nicely_when_using_http_port(self): with self.assertRaises(ServiceUnavailable): with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False): pass + + def test_custom_resolver(self): + + def my_resolver(socket_address): + self.assertEqual(socket_address, ("*", DEFAULT_PORT)) + yield "99.99.99.99", self.bolt_port # this should be rejected as unable to connect + yield "127.0.0.1", self.bolt_port # this should succeed + + with Driver("bolt://*", auth=self.auth_token, resolver=my_resolver) as driver: + with driver.session() as session: + summary = session.run("RETURN 1").summary() + self.assertEqual(summary.server.address, ("127.0.0.1", 7687))