2929from __future__ import division
3030
3131from base64 import b64encode
32- from collections import deque
32+ from collections import deque , namedtuple
3333from io import BytesIO
3434import logging
3535from os import makedirs , open as os_open , write as os_write , close as os_close , O_CREAT , O_APPEND , O_WRONLY
8181log_error = log .error
8282
8383
84+ Address = namedtuple ("Address" , ["host" , "port" ])
85+ ServerInfo = namedtuple ("ServerInfo" , ["address" , "version" ])
86+
87+
8488class BufferingSocket (object ):
8589
8690 def __init__ (self , connection ):
8791 self .connection = connection
8892 self .socket = connection .socket
89- self .address = self .socket .getpeername ()
93+ self .address = Address ( * self .socket .getpeername () )
9094 self .buffer = bytearray ()
9195
9296 def fill (self ):
@@ -132,7 +136,7 @@ class ChunkChannel(object):
132136
133137 def __init__ (self , sock ):
134138 self .socket = sock
135- self .address = sock .getpeername ()
139+ self .address = Address ( * sock .getpeername () )
136140 self .raw = BytesIO ()
137141 self .output_buffer = []
138142 self .output_size = 0
@@ -206,6 +210,22 @@ def on_ignored(self, metadata=None):
206210 pass
207211
208212
213+ class InitResponse (Response ):
214+
215+ def on_success (self , metadata ):
216+ super (InitResponse , self ).on_success (metadata )
217+ connection = self .connection
218+ address = Address (* connection .socket .getpeername ())
219+ version = metadata .get ("server" )
220+ connection .server = ServerInfo (address , version )
221+
222+ def on_failure (self , metadata ):
223+ code = metadata .get ("code" )
224+ error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else
225+ ServiceUnavailable )
226+ raise error (metadata .get ("message" , "INIT failed" ))
227+
228+
209229class Connection (object ):
210230 """ Server connection for Bolt protocol v1.
211231
@@ -219,10 +239,12 @@ class Connection(object):
219239 #: The pool of which this connection is a member
220240 pool = None
221241
242+ #: Server version details
243+ server = None
244+
222245 def __init__ (self , sock , ** config ):
223246 self .socket = sock
224247 self .buffering_socket = BufferingSocket (self )
225- self .address = sock .getpeername ()
226248 self .channel = ChunkChannel (sock )
227249 self .packer = Packer (self .channel )
228250 self .unpacker = Unpacker ()
@@ -246,15 +268,7 @@ def __init__(self, sock, **config):
246268 # Pick up the server certificate, if any
247269 self .der_encoded_server_certificate = config .get ("der_encoded_server_certificate" )
248270
249- def on_failure (metadata ):
250- code = metadata .get ("code" )
251- error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else
252- ServiceUnavailable )
253- raise error (metadata .get ("message" , "INIT failed" ))
254-
255- response = Response (self )
256- response .on_failure = on_failure
257-
271+ response = InitResponse (self )
258272 self .append (INIT , (self .user_agent , self .auth_dict ), response = response )
259273 self .send ()
260274 while not response .complete :
@@ -316,18 +330,18 @@ def send(self):
316330 """ Send all queued messages to the server.
317331 """
318332 if self .closed :
319- raise ServiceUnavailable ("Failed to write to closed connection %r" % (self .address ,))
333+ raise ServiceUnavailable ("Failed to write to closed connection %r" % (self .server . address ,))
320334 if self .defunct :
321- raise ServiceUnavailable ("Failed to write to defunct connection %r" % (self .address ,))
335+ raise ServiceUnavailable ("Failed to write to defunct connection %r" % (self .server . address ,))
322336 self .channel .send ()
323337
324338 def fetch (self ):
325339 """ Receive exactly one message from the server.
326340 """
327341 if self .closed :
328- raise ServiceUnavailable ("Failed to read from closed connection %r" % (self .address ,))
342+ raise ServiceUnavailable ("Failed to read from closed connection %r" % (self .server . address ,))
329343 if self .defunct :
330- raise ServiceUnavailable ("Failed to read from defunct connection %r" % (self .address ,))
344+ raise ServiceUnavailable ("Failed to read from defunct connection %r" % (self .server . address ,))
331345 try :
332346 message_data = self .buffering_socket .read_message ()
333347 except ProtocolError :
0 commit comments