Skip to content

Commit dfa99a3

Browse files
authored
Merge pull request #241 from technige/1.6-custom-resolver
[1.7.0] Custom resolver option
2 parents 3cddae8 + f635a0b commit dfa99a3

File tree

5 files changed

+86
-26
lines changed

5 files changed

+86
-26
lines changed

docs/source/driver.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,25 @@ The maximum time to allow for retries to be attempted when using transaction fun
144144
After this time, no more retries will be attempted.
145145
This setting does not terminate running queries.
146146

147+
``resolver``
148+
------------
149+
150+
A custom resolver function to resolve host and port values ahead of DNS resolution.
151+
This function is called with a 2-tuple of (host, port) and should return an iterable of tuples as would be returned from ``getaddrinfo``.
152+
If no custom resolver function is supplied, the internal resolver moves straight to regular DNS resolution.
153+
154+
For example::
155+
156+
def my_resolver(socket_address):
157+
if socket_address == ("foo", 9999):
158+
yield "::1", 7687
159+
yield "127.0.0.1", 7687
160+
else:
161+
from socket import gaierror
162+
raise gaierror("Unexpected socket address %r" % socket_address)
163+
164+
driver = GraphDatabase.driver("bolt+routing://foo:9999", auth=("neo4j", "password"), resolver=my_resolver)
165+
147166

148167

149168
Object Lifetime

neo4j/addressing.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,9 @@
2222
from collections import namedtuple
2323
from socket import getaddrinfo, gaierror, SOCK_STREAM, IPPROTO_TCP
2424

25-
from neo4j.compat import urlparse
25+
from neo4j.compat import urlparse, parse_qs
2626
from neo4j.exceptions import AddressError
2727

28-
try:
29-
from urllib.parse import parse_qs
30-
except ImportError:
31-
from urlparse import parse_qs
32-
3328

3429
VALID_IPv4_SEGMENTS = [str(i).encode("latin1") for i in range(0x100)]
3530
VALID_IPv6_SEGMENT_CHARS = b"0123456789abcdef"
@@ -103,17 +98,46 @@ def parse_routing_context(cls, uri):
10398
return context
10499

105100

106-
def resolve(socket_address):
107-
try:
108-
info = getaddrinfo(socket_address[0], socket_address[1], 0, SOCK_STREAM, IPPROTO_TCP)
109-
except gaierror:
110-
raise AddressError("Cannot resolve address {!r}".format(socket_address[0]))
111-
else:
112-
addresses = []
113-
for _, _, _, _, address in info:
114-
if len(address) == 4 and address[3] != 0:
115-
# skip any IPv6 addresses with a non-zero scope id
116-
# as these appear to cause problems on some platforms
117-
continue
118-
addresses.append(address)
119-
return addresses
101+
class Resolver(object):
102+
""" A Resolver instance stores a list of addresses, each in a tuple, and
103+
provides methods to perform resolution on these, thereby replacing them
104+
with the resolved values.
105+
"""
106+
107+
def __init__(self, custom_resolver=None):
108+
self.addresses = []
109+
self.custom_resolver = custom_resolver
110+
111+
def custom_resolve(self):
112+
""" If a custom resolver is defined, perform custom resolution on
113+
the contained addresses.
114+
115+
:return:
116+
"""
117+
if not callable(self.custom_resolver):
118+
return
119+
new_addresses = []
120+
for address in self.addresses:
121+
for new_address in self.custom_resolver(address):
122+
new_addresses.append(new_address)
123+
self.addresses = new_addresses
124+
125+
def dns_resolve(self):
126+
""" Perform DNS resolution on the contained addresses.
127+
128+
:return:
129+
"""
130+
new_addresses = []
131+
for address in self.addresses:
132+
try:
133+
info = getaddrinfo(address[0], address[1], 0, SOCK_STREAM, IPPROTO_TCP)
134+
except gaierror:
135+
raise AddressError("Cannot resolve address {!r}".format(address))
136+
else:
137+
for _, _, _, _, address in info:
138+
if len(address) == 4 and address[3] != 0:
139+
# skip any IPv6 addresses with a non-zero scope id
140+
# as these appear to cause problems on some platforms
141+
continue
142+
new_addresses.append(address)
143+
self.addresses = new_addresses

neo4j/bolt/connection.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from struct import pack as struct_pack, unpack as struct_unpack
3535
from threading import RLock, Condition
3636

37-
from neo4j.addressing import SocketAddress, resolve
37+
from neo4j.addressing import SocketAddress, Resolver
3838
from neo4j.bolt.cert import KNOWN_HOSTS
3939
from neo4j.bolt.response import InitResponse, AckFailureResponse, ResetResponse
4040
from neo4j.compat.ssl import SSL_AVAILABLE, HAS_SNI, SSLError
@@ -686,12 +686,16 @@ def connect(address, ssl_context=None, error_handler=None, **config):
686686
a protocol version can be agreed.
687687
"""
688688

689+
last_error = None
689690
# Establish a connection to the host and port specified
690691
# Catches refused connections see:
691692
# https://docs.python.org/2/library/errno.html
692693
log_debug("~~ [RESOLVE] %s", address)
693-
last_error = None
694-
for resolved_address in resolve(address):
694+
resolver = Resolver(custom_resolver=config.get("resolver"))
695+
resolver.addresses.append(address)
696+
resolver.custom_resolve()
697+
resolver.dns_resolve()
698+
for resolved_address in resolver.addresses:
695699
log_debug("~~ [RESOLVED] %s -> %s", address, resolved_address)
696700
try:
697701
s = _connect(resolved_address, **config)

neo4j/compat/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,6 @@ def perf_counter():
123123

124124
# The location of urlparse varies between Python 2 and 3
125125
try:
126-
from urllib.parse import urlparse
126+
from urllib.parse import urlparse, parse_qs
127127
except ImportError:
128-
from urlparse import urlparse
128+
from urlparse import urlparse, parse_qs

test/integration/test_driver.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
# limitations under the License.
2020

2121

22-
from neo4j.v1 import GraphDatabase, ServiceUnavailable
22+
from neo4j.bolt import DEFAULT_PORT
23+
from neo4j.v1 import GraphDatabase, Driver, ServiceUnavailable
2324
from test.integration.tools import IntegrationTestCase
2425

2526

@@ -44,3 +45,15 @@ def test_fail_nicely_when_using_http_port(self):
4445
with self.assertRaises(ServiceUnavailable):
4546
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False):
4647
pass
48+
49+
def test_custom_resolver(self):
50+
51+
def my_resolver(socket_address):
52+
self.assertEqual(socket_address, ("*", DEFAULT_PORT))
53+
yield "99.99.99.99", self.bolt_port # this should be rejected as unable to connect
54+
yield "127.0.0.1", self.bolt_port # this should succeed
55+
56+
with Driver("bolt://*", auth=self.auth_token, resolver=my_resolver) as driver:
57+
with driver.session() as session:
58+
summary = session.run("RETURN 1").summary()
59+
self.assertEqual(summary.server.address, ("127.0.0.1", 7687))

0 commit comments

Comments
 (0)