Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/source/driver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,24 @@ The maximum time to allow for retries to be attempted when using transaction fun
After this time, no more retries will be attempted.
This setting does not terminate running queries.

``resolver``
------------

A custom resolver function to use for DNS resolution.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can say that it is a hostname or IP address resolver and default implementation performs DNS resolution.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 Yes, that's better wording.

This function is called with a 2-tuple of (host, port) and should return an iterable of tuples as would be returned from ``getaddrinfo``.

For example::

def my_resolver(socket_address):
if socket_address == ("foo", 9999):
yield "::1", 7687
yield "127.0.0.1", 7687
else:
from socket import gaierror
raise gaierror("Unexpected socket address %r" % socket_address)

driver = GraphDatabase.driver("bolt+routing://foo:9999", auth=("neo4j", "password"), resolver=my_resolver)



Object Lifetime
Expand Down
7 changes: 1 addition & 6 deletions neo4j/addressing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,9 @@
from collections import namedtuple
from socket import getaddrinfo, gaierror, SOCK_STREAM, IPPROTO_TCP

from neo4j.compat import urlparse
from neo4j.compat import urlparse, parse_qs
from neo4j.exceptions import AddressError

try:
from urllib.parse import parse_qs
except ImportError:
from urlparse import parse_qs


VALID_IPv4_SEGMENTS = [str(i).encode("latin1") for i in range(0x100)]
VALID_IPv6_SEGMENT_CHARS = b"0123456789abcdef"
Expand Down
7 changes: 5 additions & 2 deletions neo4j/bolt/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,12 +685,15 @@ def connect(address, ssl_context=None, error_handler=None, **config):
a protocol version can be agreed.
"""

last_error = None
# Establish a connection to the host and port specified
# Catches refused connections see:
# https://docs.python.org/2/library/errno.html
log_debug("~~ [RESOLVE] %s", address)
last_error = None
for resolved_address in resolve(address):
resolver = config.get("resolver")
if not callable(resolver):
resolver = resolve
for resolved_address in resolver(address):
log_debug("~~ [RESOLVED] %s -> %s", address, resolved_address)
try:
s = _connect(resolved_address, **config)
Expand Down
4 changes: 2 additions & 2 deletions neo4j/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,6 @@ def perf_counter():

# The location of urlparse varies between Python 2 and 3
try:
from urllib.parse import urlparse
from urllib.parse import urlparse, parse_qs
except ImportError:
from urlparse import urlparse
from urlparse import urlparse, parse_qs
15 changes: 14 additions & 1 deletion test/integration/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
# limitations under the License.


from neo4j.v1 import GraphDatabase, ServiceUnavailable
from neo4j.bolt import DEFAULT_PORT
from neo4j.v1 import GraphDatabase, Driver, ServiceUnavailable
from test.integration.tools import IntegrationTestCase


Expand All @@ -43,3 +44,15 @@ def test_fail_nicely_when_using_http_port(self):
with self.assertRaises(ServiceUnavailable):
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False):
pass

def test_custom_resolver(self):

def my_resolver(socket_address):
self.assertEqual(socket_address, ("*", DEFAULT_PORT))
yield "99.99.99.99", self.bolt_port # this should be rejected as unable to connect
yield "127.0.0.1", self.bolt_port # this should succeed

with Driver("bolt://*", auth=self.auth_token, resolver=my_resolver) as driver:
with driver.session() as session:
summary = session.run("RETURN 1").summary()
self.assertEqual(summary.server.address, ("127.0.0.1", 7687))