diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 7716e0189..7e0c0c009 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -279,16 +279,19 @@ def get_handshake(cls): return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00") @classmethod - def ping(cls, address, *, timeout=None, pool_config=None): + def ping(cls, address, *, deadline=None, pool_config=None): """ Attempt to establish a Bolt connection, returning the agreed Bolt protocol version if successful. """ if pool_config is None: pool_config = PoolConfig() + if deadline is None: + deadline = Deadline(None) try: s, protocol_version, handshake, data = BoltSocket.connect( address, - timeout=timeout, + tcp_timeout=pool_config.connection_timeout, + deadline=deadline, custom_resolver=pool_config.resolver, ssl_context=pool_config.get_ssl_context(), keep_alive=pool_config.keep_alive, @@ -300,38 +303,30 @@ def ping(cls, address, *, timeout=None, pool_config=None): return protocol_version @classmethod - def open(cls, address, *, auth=None, timeout=None, routing_context=None, + def open(cls, address, *, auth=None, deadline=None, routing_context=None, pool_config=None): """ Open a new Bolt connection to a given server address. :param address: :param auth: - :param timeout: the connection timeout in seconds + :param deadline: how long to wait for the connection to be established :param routing_context: dict containing routing context :param pool_config: :return: :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version. :raise ServiceUnavailable: raised if there was a connection issue. """ - def time_remaining(): - if timeout is None: - return None - t = timeout - (perf_counter() - t0) - return t if t > 0 else 0 - t0 = perf_counter() if pool_config is None: pool_config = PoolConfig() + if deadline is None: + deadline = Deadline(None) socket_connection_timeout = pool_config.connection_timeout - if socket_connection_timeout is None: - socket_connection_timeout = time_remaining() - elif timeout is not None: - socket_connection_timeout = min(pool_config.connection_timeout, - time_remaining()) s, pool_config.protocol_version, handshake, data = BoltSocket.connect( address, - timeout=socket_connection_timeout, + tcp_timeout=pool_config.connection_timeout, + deadline=deadline, custom_resolver=pool_config.resolver, ssl_context=pool_config.get_ssl_context(), keep_alive=pool_config.keep_alive, @@ -370,7 +365,7 @@ def time_remaining(): ) try: - connection.socket.set_deadline(time_remaining()) + connection.socket.set_deadline(deadline) try: connection.hello() finally: @@ -732,9 +727,7 @@ def connection_creator(): released_reservation = False try: try: - connection = self.opener( - address, deadline.to_timeout() - ) + connection = self.opener(address, deadline) except ServiceUnavailable: self.deactivate(address) raise @@ -909,9 +902,9 @@ def open(cls, address, *, auth, pool_config, workspace_config): :return: BoltPool """ - def opener(addr, timeout): + def opener(addr, deadline): return Bolt.open( - addr, auth=auth, timeout=timeout, routing_context=None, + addr, auth=auth, deadline=deadline, routing_context=None, pool_config=pool_config ) @@ -955,8 +948,8 @@ def open(cls, *addresses, auth, pool_config, workspace_config, routing_context=N raise ConfigurationError("The key 'address' is reserved for routing context.") routing_context["address"] = str(address) - def opener(addr, timeout): - return Bolt.open(addr, auth=auth, timeout=timeout, + def opener(addr, deadline): + return Bolt.open(addr, auth=auth, deadline=deadline, routing_context=routing_context, pool_config=pool_config) diff --git a/neo4j/io/_socket.py b/neo4j/io/_socket.py index aee6b78ad..62bf07512 100644 --- a/neo4j/io/_socket.py +++ b/neo4j/io/_socket.py @@ -190,11 +190,12 @@ def _secure(cls, s, host, ssl_context): return s @classmethod - def _handshake(cls, s, resolved_address): + def _handshake(cls, s, resolved_address, deadline): """ :param s: Socket :param resolved_address: + :param deadline: :return: (socket, version, client_handshake, server_response_data) """ @@ -214,46 +215,52 @@ def _handshake(cls, s, resolved_address): log.debug("[#%04X] C: %s %s %s %s", local_port, *supported_versions) - data = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake() - s.sendall(data) + request = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake() # Handle the handshake response - ready_to_read = False - with selectors.DefaultSelector() as selector: - selector.register(s, selectors.EVENT_READ) - selector.select(1) + original_timeout = s.gettimeout() + s.settimeout(deadline.to_timeout()) try: - data = s.recv(4) - except OSError: + s.sendall(request) + response = s.recv(4) + except OSError as exc: raise ServiceUnavailable( - "Failed to read any data from server {!r} " - "after connected".format(resolved_address)) - data_size = len(data) + f"Failed to read any data from server {resolved_address!r} " + f"after connected (deadline {deadline})" + ) from exc + finally: + s.settimeout(original_timeout) + data_size = len(response) if data_size == 0: # If no data is returned after a successful select # response, the server has closed the connection log.debug("[#%04X] S: ", local_port) cls.close_socket(s) raise ServiceUnavailable( - "Connection to {address} closed without handshake response".format( - address=resolved_address)) + f"Connection to {resolved_address} closed without handshake " + "response" + ) if data_size != 4: # Some garbled data has been received log.debug("[#%04X] S: @*#!", local_port) cls.close_socket(s) raise BoltProtocolError( - "Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % ( - resolved_address, data), address=resolved_address) - elif data == b"HTTP": + "Expected four byte Bolt handshake response from " + f"{resolved_address!r}, received {response!r} instead; " + "check for incorrect port number" + , address=resolved_address + ) + elif response == b"HTTP": log.debug("[#%04X] S: ", local_port) cls.close_socket(s) raise ServiceUnavailable( - "Cannot to connect to Bolt service on {!r} " - "(looks like HTTP)".format(resolved_address)) - agreed_version = data[-1], data[-2] + f"Cannot to connect to Bolt service on {resolved_address!r} " + "(looks like HTTP)" + ) + agreed_version = response[-1], response[-2] log.debug("[#%04X] S: 0x%06X%02X", local_port, agreed_version[1], agreed_version[0]) - return cls(s), agreed_version, handshake, data + return cls(s), agreed_version, handshake, response @classmethod def close_socket(cls, socket_): @@ -269,8 +276,8 @@ def close_socket(cls, socket_): pass @classmethod - def connect(cls, address, *, timeout, custom_resolver, ssl_context, - keep_alive): + def connect(cls, address, *, tcp_timeout, deadline, custom_resolver, + ssl_context, keep_alive): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ @@ -281,12 +288,19 @@ def connect(cls, address, *, timeout, custom_resolver, ssl_context, resolved_addresses = Address(address).resolve(resolver=custom_resolver) for resolved_address in resolved_addresses: + deadline_timeout = deadline.to_timeout() + if ( + deadline_timeout is not None + and deadline_timeout <= tcp_timeout + ): + tcp_timeout = deadline_timeout s = None try: - s = BoltSocket._connect(resolved_address, timeout, keep_alive) + s = BoltSocket._connect(resolved_address, tcp_timeout, + keep_alive) s = BoltSocket._secure(s, resolved_address.host_name, ssl_context) - return BoltSocket._handshake(s, resolved_address) + return BoltSocket._handshake(s, resolved_address, deadline) except (BoltError, DriverError, OSError) as error: try: local_port = s.getsockname()[1] diff --git a/tests/unit/io/test_direct.py b/tests/unit/io/test_direct.py index 1801d6b47..d8e27fca4 100644 --- a/tests/unit/io/test_direct.py +++ b/tests/unit/io/test_direct.py @@ -113,19 +113,19 @@ class BoltTestCase(TestCase): def test_open(self): with pytest.raises(ServiceUnavailable): - connection = Bolt.open(("localhost", 9999), auth=("test", "test")) + Bolt.open(("localhost", 9999), auth=("test", "test")) def test_open_timeout(self): - conf = PoolConfig() with pytest.raises(ServiceUnavailable): - connection = Bolt.open(("localhost", 9999), auth=("test", "test"), timeout=1) + Bolt.open(("localhost", 9999), auth=("test", "test"), + deadline=Deadline(1)) def test_ping(self): protocol_version = Bolt.ping(("localhost", 9999)) assert protocol_version is None def test_ping_timeout(self): - protocol_version = Bolt.ping(("localhost", 9999), timeout=1) + protocol_version = Bolt.ping(("localhost", 9999), deadline=Deadline(1)) assert protocol_version is None