From ebb916046a19fef37c62ec78e53709ab697003ed Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 9 Oct 2024 11:48:56 +0200 Subject: [PATCH 01/18] Refactor bolt socket Move handshake and connect method that can be automatically un-async-ed for less code duplication. --- src/neo4j/_async/io/_bolt.py | 2 +- src/neo4j/_async/io/_bolt_socket.py | 207 +++++++++ src/neo4j/_async_compat/network/__init__.py | 8 +- .../_async_compat/network/_bolt_socket.py | 409 +++--------------- src/neo4j/_sync/io/_bolt.py | 2 +- src/neo4j/_sync/io/_bolt_socket.py | 207 +++++++++ tests/unit/async_/io/test_class_bolt.py | 2 +- tests/unit/sync/io/test_class_bolt.py | 2 +- 8 files changed, 476 insertions(+), 363 deletions(-) create mode 100644 src/neo4j/_async/io/_bolt_socket.py create mode 100644 src/neo4j/_sync/io/_bolt_socket.py diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 87c784569..2ae6c0569 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -23,7 +23,6 @@ from logging import getLogger from time import monotonic -from ..._async_compat.network import AsyncBoltSocket from ..._async_compat.util import AsyncUtil from ..._codec.hydration import ( HydrationHandlerABC, @@ -52,6 +51,7 @@ SessionExpired, ) from ..config import AsyncPoolConfig +from ._bolt_socket import AsyncBoltSocket from ._common import ( AsyncInbox, AsyncOutbox, diff --git a/src/neo4j/_async/io/_bolt_socket.py b/src/neo4j/_async/io/_bolt_socket.py new file mode 100644 index 000000000..0b453059d --- /dev/null +++ b/src/neo4j/_async/io/_bolt_socket.py @@ -0,0 +1,207 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import logging +import struct +from contextlib import suppress + +from ... import addressing +from ..._async_compat.network import ( + AsyncBoltSocketBase, + AsyncNetworkUtil, +) +from ..._exceptions import ( + BoltError, + BoltProtocolError, +) +from ...exceptions import ( + DriverError, + ServiceUnavailable, +) + + +log = logging.getLogger("neo4j.io") + + +class AsyncBoltSocket(AsyncBoltSocketBase): + async def _handshake(self, resolved_address, deadline): + """ + Perform BOLT handshake. + + :param resolved_address: + :param deadline: Deadline for handshake + + :returns: (version, client_handshake, server_response_data) + """ + local_port = self.getsockname()[1] + + # TODO: Optimize logging code + handshake = self.Bolt.get_handshake() + handshake = struct.unpack(">16B", handshake) + handshake = [handshake[i : i + 4] for i in range(0, len(handshake), 4)] + + supported_versions = [ + f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" + for vx in handshake + ] + + log.debug( + "[#%04X] C: 0x%08X", + local_port, + int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"), + ) + log.debug( + "[#%04X] C: %s %s %s %s", + local_port, + *supported_versions, + ) + + request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() + + # Handle the handshake response + original_timeout = self.gettimeout() + self.settimeout(deadline.to_timeout()) + try: + await self.sendall(request) + response = await self.recv(4) + except OSError as exc: + raise ServiceUnavailable( + f"Failed to read any data from server {resolved_address!r} " + f"after connected (deadline {deadline})" + ) from exc + finally: + self.settimeout(original_timeout) + data_size = len(response) + if data_size == 0: + # If no data is returned after a successful select + # response, the server has closed the connection + log.debug("[#%04X] S: ", local_port) + await self.close() + raise ServiceUnavailable( + f"Connection to {resolved_address} closed without handshake " + "response" + ) + if data_size != 4: + # Some garbled data has been received + log.debug("[#%04X] S: @*#!", local_port) + await self.close() + raise BoltProtocolError( + "Expected four byte Bolt handshake response from " + f"{resolved_address!r}, received {response!r} instead; " + "check for incorrect port number", + address=resolved_address, + ) + elif response == b"HTTP": + log.debug("[#%04X] S: ", local_port) + await self.close() + raise ServiceUnavailable( + f"Cannot to connect to Bolt service on {resolved_address!r} " + "(looks like HTTP)" + ) + agreed_version = response[-1], response[-2] + log.debug( + "[#%04X] S: 0x%06X%02X", + local_port, + agreed_version[1], + agreed_version[0], + ) + return agreed_version, handshake, response + + @classmethod + async def connect( + cls, + address, + *, + tcp_timeout, + deadline, + custom_resolver, + ssl_context, + keep_alive, + ): + """ + Connect and perform a handshake. + + Return a valid Connection object, assuming a protocol version can be + agreed. + """ + errors = [] + failed_addresses = [] + # Establish a connection to the host and port specified + # Catches refused connections see: + # https://docs.python.org/2/library/errno.html + + resolved_addresses = AsyncNetworkUtil.resolve_address( + addressing.Address(address), resolver=custom_resolver + ) + async for resolved_address in resolved_addresses: + deadline_timeout = deadline.to_timeout() + if ( + deadline_timeout is not None + and deadline_timeout <= tcp_timeout + ): + tcp_timeout = deadline_timeout + s = None + try: + s = await cls._connect_secure( + resolved_address, tcp_timeout, keep_alive, ssl_context + ) + return (s, *await s._handshake(resolved_address, deadline)) + except (BoltError, DriverError, OSError) as error: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError, TypeError): + local_port = 0 + err_str = error.__class__.__name__ + if str(error): + err_str += ": " + str(error) + log.debug( + "[#%04X] S: %s %s", + local_port, + resolved_address, + err_str, + ) + if s: + await cls.close_socket(s) + errors.append(error) + failed_addresses.append(resolved_address) + except asyncio.CancelledError: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError, TypeError): + local_port = 0 + log.debug( + "[#%04X] C: %s", local_port, resolved_address + ) + if s: + with suppress(OSError): + s.kill() + raise + except Exception: + if s: + await cls.close_socket(s) + raise + address_strs = tuple(map(str, failed_addresses)) + if not errors: + raise ServiceUnavailable( + f"Couldn't connect to {address} (resolved to {address_strs})" + ) + else: + error_strs = "\n".join(map(str, errors)) + raise ServiceUnavailable( + f"Couldn't connect to {address} (resolved to {address_strs}):" + f"\n{error_strs}" + ) from errors[0] diff --git a/src/neo4j/_async_compat/network/__init__.py b/src/neo4j/_async_compat/network/__init__.py index e5777595f..4da0f32e6 100644 --- a/src/neo4j/_async_compat/network/__init__.py +++ b/src/neo4j/_async_compat/network/__init__.py @@ -15,8 +15,8 @@ from ._bolt_socket import ( - AsyncBoltSocket, - BoltSocket, + AsyncBoltSocketBase, + BoltSocketBase, ) from ._util import ( AsyncNetworkUtil, @@ -25,8 +25,8 @@ __all__ = [ - "AsyncBoltSocket", + "AsyncBoltSocketBase", "AsyncNetworkUtil", - "BoltSocket", + "BoltSocketBase", "NetworkUtil", ] diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py index 5f92f831d..ec14efa24 100644 --- a/src/neo4j/_async_compat/network/_bolt_socket.py +++ b/src/neo4j/_async_compat/network/_bolt_socket.py @@ -16,9 +16,9 @@ from __future__ import annotations +import abc import asyncio import logging -import struct import typing as t from contextlib import suppress @@ -44,23 +44,14 @@ SSLSocket, ) -from ... import addressing from ..._deadline import Deadline from ..._exceptions import ( - BoltError, BoltProtocolError, BoltSecurityError, SocketDeadlineExceededError, ) -from ...exceptions import ( - DriverError, - ServiceUnavailable, -) +from ...exceptions import ServiceUnavailable from ..shims import wait_for -from ._util import ( - AsyncNetworkUtil, - NetworkUtil, -) if t.TYPE_CHECKING: @@ -82,7 +73,7 @@ def _sanitize_deadline(deadline): return deadline -class AsyncBoltSocket: +class AsyncBoltSocketBase(abc.ABC): Bolt: te.Final[type[AsyncBolt]] = None # type: ignore[assignment] def __init__(self, reader, protocol, writer): @@ -177,19 +168,22 @@ def kill(self): self._writer.close() @classmethod - async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): + async def _connect_secure( + cls, resolved_address, timeout, keep_alive, ssl_context + ): """ Connect to the address and return the socket. :param resolved_address: :param timeout: seconds :param keep_alive: True or False - :param ssl: SSLContext or None + :param ssl_context: SSLContext or None :returns: AsyncBoltSocket object """ loop = asyncio.get_event_loop() s = None + local_port = 0 # TODO: tomorrow me: fix this mess try: @@ -202,17 +196,18 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): s.setblocking(False) # asyncio + blocking = no-no! log.debug("[#0000] C: %s", resolved_address) await wait_for(loop.sock_connect(s, resolved_address), timeout) + local_port = s.getsockname()[1] keep_alive = 1 if keep_alive else 0 s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) ssl_kwargs = {} - if ssl is not None: + if ssl_context is not None: hostname = resolved_address._host_name or None - ssl_kwargs.update( - ssl=ssl, server_hostname=hostname if HAS_SNI else None - ) + sni_host = hostname if HAS_SNI and hostname else None + ssl_kwargs.update(ssl=ssl_context, server_hostname=sni_host) + log.debug("[#%04X] C: %s", local_port, hostname) reader = asyncio.StreamReader( limit=2**16, # 64 KiB, @@ -224,7 +219,7 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): ) writer = asyncio.StreamWriter(transport, protocol, reader, loop) - if ssl is not None: + if ssl_context is not None: # Check that the server provides a certificate der_encoded_server_certificate = transport.get_extra_info( "ssl_object" @@ -255,7 +250,6 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): cls._kill_raw_socket(s) raise except (SSLError, CertificateError) as error: - local_port = s.getsockname()[1] if s: cls._kill_raw_socket(s) raise BoltSecurityError( @@ -278,92 +272,25 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): ) from error raise - async def _handshake(self, resolved_address, deadline): - """ - Perform BOLT handshake. + @abc.abstractmethod + async def _handshake(self, resolved_address, deadline): ... - :param resolved_address: - :param deadline: Deadline for handshake - - :returns: (socket, version, client_handshake, server_response_data) - """ - local_port = self.getsockname()[1] - - # TODO: Optimize logging code - handshake = self.Bolt.get_handshake() - handshake = struct.unpack(">16B", handshake) - handshake = [handshake[i : i + 4] for i in range(0, len(handshake), 4)] - - supported_versions = [ - f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" - for vx in handshake - ] - - log.debug( - "[#%04X] C: 0x%08X", - local_port, - int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"), - ) - log.debug( - "[#%04X] C: %s %s %s %s", - local_port, - *supported_versions, - ) - - request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() - - # Handle the handshake response - original_timeout = self.gettimeout() - self.settimeout(deadline.to_timeout()) - try: - await self.sendall(request) - response = await self.recv(4) - except OSError as exc: - raise ServiceUnavailable( - f"Failed to read any data from server {resolved_address!r} " - f"after connected (deadline {deadline})" - ) from exc - finally: - self.settimeout(original_timeout) - data_size = len(response) - if data_size == 0: - # If no data is returned after a successful select - # response, the server has closed the connection - log.debug("[#%04X] S: ", local_port) - await self.close() - raise ServiceUnavailable( - f"Connection to {resolved_address} closed without handshake " - "response" - ) - if data_size != 4: - # Some garbled data has been received - log.debug("[#%04X] S: @*#!", local_port) - await self.close() - raise BoltProtocolError( - "Expected four byte Bolt handshake response from " - f"{resolved_address!r}, received {response!r} instead; " - "check for incorrect port number", - address=resolved_address, - ) - elif response == b"HTTP": - log.debug("[#%04X] S: ", local_port) - await self.close() - raise ServiceUnavailable( - f"Cannot to connect to Bolt service on {resolved_address!r} " - "(looks like HTTP)" - ) - agreed_version = response[-1], response[-2] - log.debug( - "[#%04X] S: 0x%06X%02X", - local_port, - agreed_version[1], - agreed_version[0], - ) - return self, agreed_version, handshake, response + @classmethod + @abc.abstractmethod + async def connect( + cls, + address, + *, + tcp_timeout, + deadline, + custom_resolver, + ssl_context, + keep_alive, + ): ... @classmethod async def close_socket(cls, socket_): - if isinstance(socket_, AsyncBoltSocket): + if isinstance(socket_, AsyncBoltSocketBase): with suppress(OSError): await socket_.close() else: @@ -376,93 +303,8 @@ def _kill_raw_socket(cls, socket_): with suppress(OSError): socket_.close() - @classmethod - async def connect( - cls, - address, - *, - tcp_timeout, - deadline, - custom_resolver, - ssl_context, - keep_alive, - ): - """ - Connect and perform a handshake. - - Return a valid Connection object, assuming a protocol version can be - agreed. - """ - errors = [] - failed_addresses = [] - # Establish a connection to the host and port specified - # Catches refused connections see: - # https://docs.python.org/2/library/errno.html - - resolved_addresses = AsyncNetworkUtil.resolve_address( - addressing.Address(address), resolver=custom_resolver - ) - async for resolved_address in resolved_addresses: - deadline_timeout = deadline.to_timeout() - if ( - deadline_timeout is not None - and deadline_timeout <= tcp_timeout - ): - tcp_timeout = deadline_timeout - s = None - try: - s = await cls._connect_secure( - resolved_address, tcp_timeout, keep_alive, ssl_context - ) - return await s._handshake(resolved_address, deadline) - except (BoltError, DriverError, OSError) as error: - try: - local_port = s.getsockname()[1] - except (OSError, AttributeError, TypeError): - local_port = 0 - err_str = error.__class__.__name__ - if str(error): - err_str += ": " + str(error) - log.debug( - "[#%04X] C: %s %s", - local_port, - resolved_address, - err_str, - ) - if s: - await cls.close_socket(s) - errors.append(error) - failed_addresses.append(resolved_address) - except asyncio.CancelledError: - try: - local_port = s.getsockname()[1] - except (OSError, AttributeError, TypeError): - local_port = 0 - log.debug( - "[#%04X] C: %s", local_port, resolved_address - ) - if s: - with suppress(OSError): - s.kill() - raise - except Exception: - if s: - await cls.close_socket(s) - raise - address_strs = tuple(map(str, failed_addresses)) - if not errors: - raise ServiceUnavailable( - f"Couldn't connect to {address} (resolved to {address_strs})" - ) - else: - error_strs = "\n".join(map(str, errors)) - raise ServiceUnavailable( - f"Couldn't connect to {address} (resolved to {address_strs}):" - f"\n{error_strs}" - ) from errors[0] - -class BoltSocket: +class BoltSocketBase: Bolt: te.Final[type[Bolt]] = None # type: ignore[assignment] def __init__(self, socket_: socket): @@ -530,7 +372,9 @@ def kill(self): self._socket.close() @classmethod - def _connect(cls, resolved_address, timeout, keep_alive): + def _connect_secure( + cls, resolved_address, timeout, keep_alive, ssl_context + ): """ Connect to the address and return the socket. @@ -556,7 +400,6 @@ def _connect(cls, resolved_address, timeout, keep_alive): s.settimeout(t) keep_alive = 1 if keep_alive else 0 s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) - return s except SocketTimeout: log.debug("[#0000] S: %s", resolved_address) log.debug("[#0000] C: %s", resolved_address) @@ -580,20 +423,19 @@ def _connect(cls, resolved_address, timeout, keep_alive): ) from error raise - @classmethod - def _secure(cls, s, host, ssl_context): local_port = s.getsockname()[1] # Secure the connection if an SSL context has been provided if ssl_context: - log.debug("[#%04X] C: %s", local_port, host) + hostname = resolved_address._host_name or None + sni_host = hostname if HAS_SNI and hostname else None + log.debug("[#%04X] C: %s", local_port, hostname) try: - sni_host = host if HAS_SNI and host else None s = ssl_context.wrap_socket(s, server_hostname=sni_host) except (OSError, SSLError, CertificateError) as cause: cls._kill_raw_socket(s) raise BoltSecurityError( message="Failed to establish encrypted connection.", - address=(host, local_port), + address=(hostname, local_port), ) from cause # Check that the server provides a certificate der_encoded_server_certificate = s.getpeercert(binary_form=True) @@ -601,110 +443,16 @@ def _secure(cls, s, host, ssl_context): raise BoltProtocolError( "When using an encrypted socket, the server should always " "provide a certificate", - address=(host, local_port), + address=(hostname, local_port), ) - return s - return s - - @classmethod - def _handshake(cls, s, resolved_address, deadline): - """ - Perform BOLT handshake. - - :param s: Socket - :param resolved_address: - :param deadline: - - :returns: (socket, version, client_handshake, server_response_data) - """ - local_port = s.getsockname()[1] - - # TODO: Optimize logging code - handshake = cls.Bolt.get_handshake() - handshake = struct.unpack(">16B", handshake) - handshake = [handshake[i : i + 4] for i in range(0, len(handshake), 4)] - - supported_versions = [ - f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" - for vx in handshake - ] - - log.debug( - "[#%04X] C: 0x%08X", - local_port, - int.from_bytes(cls.Bolt.MAGIC_PREAMBLE, byteorder="big"), - ) - log.debug( - "[#%04X] C: %s %s %s %s", - local_port, - *supported_versions, - ) - - request = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake() - # Handle the handshake response - original_timeout = s.gettimeout() - s.settimeout(deadline.to_timeout()) - try: - s.sendall(request) - response = s.recv(4) - except OSError as exc: - raise ServiceUnavailable( - f"Failed to read any data from server {resolved_address!r} " - f"after connected (deadline {deadline})" - ) from exc - finally: - s.settimeout(original_timeout) - data_size = len(response) - if data_size == 0: - # If no data is returned after a successful select - # response, the server has closed the connection - log.debug("[#%04X] S: ", local_port) - cls._kill_raw_socket(s) - raise ServiceUnavailable( - f"Connection to {resolved_address} closed without handshake " - "response" - ) - if data_size != 4: - # Some garbled data has been received - log.debug("[#%04X] S: @*#!", local_port) - cls._kill_raw_socket(s) - raise BoltProtocolError( - "Expected four byte Bolt handshake response from " - f"{resolved_address!r}, received {response!r} instead; " - "check for incorrect port number", - address=resolved_address, - ) - elif response == b"HTTP": - log.debug("[#%04X] S: ", local_port) - cls._kill_raw_socket(s) - raise ServiceUnavailable( - f"Cannot to connect to Bolt service on {resolved_address!r} " - "(looks like HTTP)" - ) - agreed_version = response[-1], response[-2] - log.debug( - "[#%04X] S: 0x%06X%02X", - local_port, - agreed_version[1], - agreed_version[0], - ) - return cls(s), agreed_version, handshake, response - - @classmethod - def close_socket(cls, socket_): - if isinstance(socket_, BoltSocket): - socket_ = socket_._socket - cls._kill_raw_socket(socket_) + return cls(s) - @classmethod - def _kill_raw_socket(cls, socket_): - with suppress(OSError): - socket_.shutdown(SHUT_RDWR) - with suppress(OSError): - socket_.close() + @abc.abstractmethod + def _handshake(self, resolved_address, deadline): ... @classmethod + @abc.abstractmethod def connect( cls, address, @@ -714,66 +462,17 @@ def connect( custom_resolver, ssl_context, keep_alive, - ): - """ - Connect and perform a handshake. + ): ... - Return a valid Connection object, assuming a protocol version can be - agreed. - """ - errors = [] - # Establish a connection to the host and port specified - # Catches refused connections see: - # https://docs.python.org/2/library/errno.html + @classmethod + def close_socket(cls, socket_): + if isinstance(socket_, BoltSocketBase): + socket_ = socket_._socket + cls._kill_raw_socket(socket_) - resolved_addresses = NetworkUtil.resolve_address( - addressing.Address(address), resolver=custom_resolver - ) - for resolved_address in resolved_addresses: - deadline_timeout = deadline.to_timeout() - if ( - deadline_timeout is not None - and deadline_timeout <= tcp_timeout - ): - tcp_timeout = deadline_timeout - s = None - try: - s = BoltSocket._connect( - resolved_address, tcp_timeout, keep_alive - ) - s = BoltSocket._secure( - s, resolved_address._host_name, ssl_context - ) - return BoltSocket._handshake(s, resolved_address, deadline) - except (BoltError, DriverError, OSError) as error: - try: - local_port = s.getsockname()[1] - except (OSError, AttributeError): - local_port = 0 - err_str = error.__class__.__name__ - if str(error): - err_str += ": " + str(error) - log.debug( - "[#%04X] S: %s", local_port, err_str - ) - if s: - cls.close_socket(s) - errors.append(error) - except Exception: - if s: - cls.close_socket(s) - raise - if not errors: - resolved_address_strs = tuple(map(str, resolved_addresses)) - raise ServiceUnavailable( - f"Couldn't connect to {address} " - f"(resolved to {resolved_address_strs})" - ) - else: - resolved_address_strs = tuple(map(str, resolved_addresses)) - error_strs = "\n".join(map(str, errors)) - raise ServiceUnavailable( - f"Couldn't connect to {address} " - f"(resolved to {resolved_address_strs}):\n" - f"{error_strs}" - ) from errors[0] + @classmethod + def _kill_raw_socket(cls, socket_): + with suppress(OSError): + socket_.shutdown(SHUT_RDWR) + with suppress(OSError): + socket_.close() diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 3aa2f020d..9173813d9 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -23,7 +23,6 @@ from logging import getLogger from time import monotonic -from ..._async_compat.network import BoltSocket from ..._async_compat.util import Util from ..._codec.hydration import ( HydrationHandlerABC, @@ -52,6 +51,7 @@ SessionExpired, ) from ..config import PoolConfig +from ._bolt_socket import BoltSocket from ._common import ( CommitResponse, Inbox, diff --git a/src/neo4j/_sync/io/_bolt_socket.py b/src/neo4j/_sync/io/_bolt_socket.py new file mode 100644 index 000000000..23cf0257b --- /dev/null +++ b/src/neo4j/_sync/io/_bolt_socket.py @@ -0,0 +1,207 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import logging +import struct +from contextlib import suppress + +from ... import addressing +from ..._async_compat.network import ( + BoltSocketBase, + NetworkUtil, +) +from ..._exceptions import ( + BoltError, + BoltProtocolError, +) +from ...exceptions import ( + DriverError, + ServiceUnavailable, +) + + +log = logging.getLogger("neo4j.io") + + +class BoltSocket(BoltSocketBase): + def _handshake(self, resolved_address, deadline): + """ + Perform BOLT handshake. + + :param resolved_address: + :param deadline: Deadline for handshake + + :returns: (version, client_handshake, server_response_data) + """ + local_port = self.getsockname()[1] + + # TODO: Optimize logging code + handshake = self.Bolt.get_handshake() + handshake = struct.unpack(">16B", handshake) + handshake = [handshake[i : i + 4] for i in range(0, len(handshake), 4)] + + supported_versions = [ + f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" + for vx in handshake + ] + + log.debug( + "[#%04X] C: 0x%08X", + local_port, + int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"), + ) + log.debug( + "[#%04X] C: %s %s %s %s", + local_port, + *supported_versions, + ) + + request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() + + # Handle the handshake response + original_timeout = self.gettimeout() + self.settimeout(deadline.to_timeout()) + try: + self.sendall(request) + response = self.recv(4) + except OSError as exc: + raise ServiceUnavailable( + f"Failed to read any data from server {resolved_address!r} " + f"after connected (deadline {deadline})" + ) from exc + finally: + self.settimeout(original_timeout) + data_size = len(response) + if data_size == 0: + # If no data is returned after a successful select + # response, the server has closed the connection + log.debug("[#%04X] S: ", local_port) + self.close() + raise ServiceUnavailable( + f"Connection to {resolved_address} closed without handshake " + "response" + ) + if data_size != 4: + # Some garbled data has been received + log.debug("[#%04X] S: @*#!", local_port) + self.close() + raise BoltProtocolError( + "Expected four byte Bolt handshake response from " + f"{resolved_address!r}, received {response!r} instead; " + "check for incorrect port number", + address=resolved_address, + ) + elif response == b"HTTP": + log.debug("[#%04X] S: ", local_port) + self.close() + raise ServiceUnavailable( + f"Cannot to connect to Bolt service on {resolved_address!r} " + "(looks like HTTP)" + ) + agreed_version = response[-1], response[-2] + log.debug( + "[#%04X] S: 0x%06X%02X", + local_port, + agreed_version[1], + agreed_version[0], + ) + return agreed_version, handshake, response + + @classmethod + def connect( + cls, + address, + *, + tcp_timeout, + deadline, + custom_resolver, + ssl_context, + keep_alive, + ): + """ + Connect and perform a handshake. + + Return a valid Connection object, assuming a protocol version can be + agreed. + """ + errors = [] + failed_addresses = [] + # Establish a connection to the host and port specified + # Catches refused connections see: + # https://docs.python.org/2/library/errno.html + + resolved_addresses = NetworkUtil.resolve_address( + addressing.Address(address), resolver=custom_resolver + ) + for resolved_address in resolved_addresses: + deadline_timeout = deadline.to_timeout() + if ( + deadline_timeout is not None + and deadline_timeout <= tcp_timeout + ): + tcp_timeout = deadline_timeout + s = None + try: + s = cls._connect_secure( + resolved_address, tcp_timeout, keep_alive, ssl_context + ) + return (s, *s._handshake(resolved_address, deadline)) + except (BoltError, DriverError, OSError) as error: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError, TypeError): + local_port = 0 + err_str = error.__class__.__name__ + if str(error): + err_str += ": " + str(error) + log.debug( + "[#%04X] S: %s %s", + local_port, + resolved_address, + err_str, + ) + if s: + cls.close_socket(s) + errors.append(error) + failed_addresses.append(resolved_address) + except asyncio.CancelledError: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError, TypeError): + local_port = 0 + log.debug( + "[#%04X] C: %s", local_port, resolved_address + ) + if s: + with suppress(OSError): + s.kill() + raise + except Exception: + if s: + cls.close_socket(s) + raise + address_strs = tuple(map(str, failed_addresses)) + if not errors: + raise ServiceUnavailable( + f"Couldn't connect to {address} (resolved to {address_strs})" + ) + else: + error_strs = "\n".join(map(str, errors)) + raise ServiceUnavailable( + f"Couldn't connect to {address} (resolved to {address_strs}):" + f"\n{error_strs}" + ) from errors[0] diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 5bbc50e85..a34db62ef 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -20,7 +20,7 @@ import neo4j.auth_management from neo4j._async.io import AsyncBolt -from neo4j._async_compat.network import AsyncBoltSocket +from neo4j._async.io._bolt_socket import AsyncBoltSocket from neo4j._exceptions import BoltHandshakeError from ...._async_compat import ( diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index f82d441dd..bab854f87 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -19,9 +19,9 @@ import pytest import neo4j.auth_management -from neo4j._async_compat.network import BoltSocket from neo4j._exceptions import BoltHandshakeError from neo4j._sync.io import Bolt +from neo4j._sync.io._bolt_socket import BoltSocket from ...._async_compat import ( mark_sync_test, From a1584b03286cface3b97e946823b1c176f9145ed Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 9 Oct 2024 18:17:15 +0200 Subject: [PATCH 02/18] WIP: Bolt handshake v2 TODO: tests --- src/neo4j/_async/io/_bolt.py | 30 +-- src/neo4j/_async/io/_bolt_socket.py | 239 +++++++++++++++++++----- src/neo4j/_sync/io/_bolt.py | 30 +-- src/neo4j/_sync/io/_bolt_socket.py | 239 +++++++++++++++++++----- tests/unit/async_/io/test_class_bolt.py | 49 +++-- tests/unit/sync/io/test_class_bolt.py | 49 +++-- 6 files changed, 448 insertions(+), 188 deletions(-) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 2ae6c0569..85aa5b439 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -274,7 +274,7 @@ def assert_notification_filtering_support(self): # [bolt-version-bump] search tag when changing bolt version support @classmethod - def protocol_handlers(cls, protocol_version=None): + def protocol_handlers(cls): """ Return a dictionary of available Bolt protocol handlers. @@ -293,7 +293,6 @@ def protocol_handlers(cls, protocol_version=None): # issues. from ._bolt3 import AsyncBolt3 from ._bolt4 import ( - AsyncBolt4x1, AsyncBolt4x2, AsyncBolt4x3, AsyncBolt4x4, @@ -308,10 +307,9 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x6, ) - handlers = { + return { AsyncBolt3.PROTOCOL_VERSION: AsyncBolt3, - # 4.0 unsupported because no space left in the handshake - AsyncBolt4x1.PROTOCOL_VERSION: AsyncBolt4x1, + # 4.0-4.1 unsupported because no space left in the handshake AsyncBolt4x2.PROTOCOL_VERSION: AsyncBolt4x2, AsyncBolt4x3.PROTOCOL_VERSION: AsyncBolt4x3, AsyncBolt4x4.PROTOCOL_VERSION: AsyncBolt4x4, @@ -324,19 +322,8 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, } - if protocol_version is None: - return handlers - - if not isinstance(protocol_version, tuple): - raise TypeError("Protocol version must be specified as a tuple") - - if protocol_version in handlers: - return {protocol_version: handlers[protocol_version]} - - return {} - @classmethod - def version_list(cls, versions): + def version_list(cls, versions, limit=4): """ Return a list of supported protocol versions in order of preference. @@ -359,7 +346,7 @@ def version_list(cls, versions): result[-1][1][1] = version[1] continue result.append(Version(version[0], [version[1], version[1]])) - if len(result) == 4: + if len(result) >= limit: break return result @@ -374,8 +361,11 @@ def get_handshake(cls): supported_versions = sorted( cls.protocol_handlers().keys(), reverse=True ) - offered_versions = cls.version_list(supported_versions) - versions_bytes = (v.to_bytes() for v in offered_versions) + offered_versions = cls.version_list(supported_versions, limit=3) + versions_bytes = ( + Version(0xFF, 1).to_bytes(), # handshake v2 + *(v.to_bytes() for v in offered_versions), + ) return b"".join(versions_bytes).ljust(16, b"\x00") @classmethod diff --git a/src/neo4j/_async/io/_bolt_socket.py b/src/neo4j/_async/io/_bolt_socket.py index 0b453059d..f1dc715d3 100644 --- a/src/neo4j/_async/io/_bolt_socket.py +++ b/src/neo4j/_async/io/_bolt_socket.py @@ -14,9 +14,13 @@ # limitations under the License. +from __future__ import annotations + import asyncio +import dataclasses import logging import struct +import typing as t from contextlib import suppress from ... import addressing @@ -34,54 +38,120 @@ ) +if t.TYPE_CHECKING: + from ..._deadline import Deadline + + log = logging.getLogger("neo4j.io") -class AsyncBoltSocket(AsyncBoltSocketBase): - async def _handshake(self, resolved_address, deadline): - """ - Perform BOLT handshake. +@dataclasses.dataclass +class HandshakeCtx: + ctx: str + deadline: Deadline + local_port: int + resolved_address: addressing.ResolvedAddress + full_response: bytearray = dataclasses.field(default_factory=bytearray) - :param resolved_address: - :param deadline: Deadline for handshake - :returns: (version, client_handshake, server_response_data) - """ - local_port = self.getsockname()[1] +@dataclasses.dataclass +class BytesPrinter: + bytes: bytes | bytearray - # TODO: Optimize logging code - handshake = self.Bolt.get_handshake() - handshake = struct.unpack(">16B", handshake) - handshake = [handshake[i : i + 4] for i in range(0, len(handshake), 4)] + def __str__(self): + return f"0x{self.bytes.hex().upper()}" - supported_versions = [ - f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" - for vx in handshake - ] +class AsyncBoltSocket(AsyncBoltSocketBase): + async def _parse_handshake_response_v1(self, ctx, response): + agreed_version = response[-1], response[-2] log.debug( - "[#%04X] C: 0x%08X", - local_port, - int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"), + "[#%04X] S: 0x%06X%02X", + ctx.local_port, + agreed_version[1], + agreed_version[0], + ) + return agreed_version + + async def _parse_handshake_response_v2(self, ctx, response): + ctx.ctx = "handshake v2 offerings count" + num_offerings = await self._read_varint(ctx) + offerings = [] + for i in range(num_offerings): + ctx.ctx = f"handshake v2 offering {i}" + offering_response = await self._handshake_read(ctx, 4) + offering = offering_response[-1:-4:-1] + offerings.append(offering) + ctx.ctx = "handshake v2 capabilities" + _capabilities_offer = await self._read_varint(ctx) + + if log.getEffectiveLevel() >= logging.DEBUG: + log.debug( + "[#%04X] S: %s [%i] %s %s", + ctx.local_port, + BytesPrinter(response), + num_offerings, + " ".join(f"0x{vx[1]:06X}{vx[0]:02X}" for vx in offerings), + BytesPrinter(self._encode_varint(_capabilities_offer)), + ) + + supported_versions = sorted(self.Bolt.protocol_handlers().keys()) + chosen_version = 0, 0 + for v in supported_versions: + for offer_major, offer_minor, offer_range in offerings: + offer_max = (offer_major, offer_minor) + offer_min = (offer_major, offer_minor - offer_range) + if offer_min <= v <= offer_max: + chosen_version = v + break + + ctx.ctx = "handshake v2 chosen version" + await self._handshake_send( + ctx, bytes((0, 0, chosen_version[1], chosen_version[0])) ) + chosen_capabilities = 0 + capabilities = self._encode_varint(chosen_capabilities) + ctx.ctx = "handshake v2 chosen capabilities" log.debug( - "[#%04X] C: %s %s %s %s", - local_port, - *supported_versions, + "[#%04X] C: 0x%06X%02X %s", + ctx.local_port, + chosen_version[1], + chosen_version[0], + BytesPrinter(capabilities), ) + await self._handshake_send(ctx, b"\x00") - request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() + return chosen_version + + async def _read_varint(self, ctx): + next_byte = (await self._handshake_read(ctx, 1))[0] + res = next_byte & 0x7F + i = 0 + while next_byte & 0x80: + i += 1 + next_byte = (await self._handshake_read(ctx, 1))[0] + res += (next_byte & 0x7F) << (7 * i) + return res - # Handle the handshake response + @staticmethod + def _encode_varint(n): + res = bytearray() + while n >= 0x80: + res.append(n & 0x7F | 0x80) + n >>= 7 + res.append(n) + return res + + async def _handshake_read(self, ctx, n): original_timeout = self.gettimeout() - self.settimeout(deadline.to_timeout()) + self.settimeout(ctx.deadline.to_timeout()) try: - await self.sendall(request) - response = await self.recv(4) + response = await self.recv(n) + ctx.full_response.extend(response) except OSError as exc: raise ServiceUnavailable( - f"Failed to read any data from server {resolved_address!r} " - f"after connected (deadline {deadline})" + f"Failed to read {ctx.ctx} from server " + f"{ctx.resolved_address!r} (deadline {ctx.deadline})" ) from exc finally: self.settimeout(original_timeout) @@ -89,36 +159,113 @@ async def _handshake(self, resolved_address, deadline): if data_size == 0: # If no data is returned after a successful select # response, the server has closed the connection - log.debug("[#%04X] S: ", local_port) + log.debug("[#%04X] S: ", ctx.local_port) await self.close() raise ServiceUnavailable( - f"Connection to {resolved_address} closed without handshake " - "response" + f"Connection to {ctx.resolved_address} closed with incomplete " + f"handshake response" ) - if data_size != 4: + if data_size != n: # Some garbled data has been received - log.debug("[#%04X] S: @*#!", local_port) + log.debug("[#%04X] S: @*#!", ctx.local_port) await self.close() raise BoltProtocolError( - "Expected four byte Bolt handshake response from " - f"{resolved_address!r}, received {response!r} instead; " + f"Expected {ctx.ctx} from {ctx.resolved_address!r}, received " + f"{response!r} instead (so far {ctx.full_response!r}); " "check for incorrect port number", - address=resolved_address, + address=ctx.resolved_address, + ) + + return response + + async def _handshake_send(self, ctx, data): + original_timeout = self.gettimeout() + self.settimeout(ctx.deadline.to_timeout()) + try: + await self.sendall(data) + except OSError as exc: + raise ServiceUnavailable( + f"Failed to write {ctx.ctx} to server " + f"{ctx.resolved_address!r} (deadline {ctx.deadline})" + ) from exc + finally: + self.settimeout(original_timeout) + + async def _handshake(self, resolved_address, deadline): + """ + Perform BOLT handshake. + + :param resolved_address: + :param deadline: Deadline for handshake + + :returns: (version, client_handshake, server_response_data) + """ + local_port = self.getsockname()[1] + + if log.getEffectiveLevel() >= logging.DEBUG: + handshake = self.Bolt.get_handshake() + handshake = struct.unpack(">16B", handshake) + handshake = [ + handshake[i : i + 4] for i in range(0, len(handshake), 4) + ] + + supported_versions = [ + f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" + for vx in handshake + ] + + log.debug( + "[#%04X] C: 0x%08X", + local_port, + int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"), + ) + log.debug( + "[#%04X] C: %s %s %s %s", + local_port, + *supported_versions, ) - elif response == b"HTTP": + + request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() + + ctx = HandshakeCtx( + ctx="handshake opening", + deadline=deadline, + local_port=local_port, + resolved_address=resolved_address, + ) + + await self._handshake_send(ctx, request) + + ctx.ctx = "four byte Bolt handshake response" + response = await self._handshake_read(ctx, 4) + + if response == b"HTTP": log.debug("[#%04X] S: ", local_port) await self.close() raise ServiceUnavailable( f"Cannot to connect to Bolt service on {resolved_address!r} " "(looks like HTTP)" ) - agreed_version = response[-1], response[-2] - log.debug( - "[#%04X] S: 0x%06X%02X", - local_port, - agreed_version[1], - agreed_version[0], - ) + elif response[-1] == 0xFF: + # manifest style handshake + manifest_version = response[-2] + if manifest_version == 0x01: + agreed_version = await self._parse_handshake_response_v2( + ctx, + response, + ) + else: + raise BoltProtocolError( + "Unsupported Bolt handshake manifest version " + f"{manifest_version} received from {resolved_address!r}.", + address=resolved_address, + ) + else: + agreed_version = await self._parse_handshake_response_v1( + ctx, + response, + ) + return agreed_version, handshake, response @classmethod diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 9173813d9..126e2fd6d 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -274,7 +274,7 @@ def assert_notification_filtering_support(self): # [bolt-version-bump] search tag when changing bolt version support @classmethod - def protocol_handlers(cls, protocol_version=None): + def protocol_handlers(cls): """ Return a dictionary of available Bolt protocol handlers. @@ -293,7 +293,6 @@ def protocol_handlers(cls, protocol_version=None): # issues. from ._bolt3 import Bolt3 from ._bolt4 import ( - Bolt4x1, Bolt4x2, Bolt4x3, Bolt4x4, @@ -308,10 +307,9 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x6, ) - handlers = { + return { Bolt3.PROTOCOL_VERSION: Bolt3, - # 4.0 unsupported because no space left in the handshake - Bolt4x1.PROTOCOL_VERSION: Bolt4x1, + # 4.0-4.1 unsupported because no space left in the handshake Bolt4x2.PROTOCOL_VERSION: Bolt4x2, Bolt4x3.PROTOCOL_VERSION: Bolt4x3, Bolt4x4.PROTOCOL_VERSION: Bolt4x4, @@ -324,19 +322,8 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x6.PROTOCOL_VERSION: Bolt5x6, } - if protocol_version is None: - return handlers - - if not isinstance(protocol_version, tuple): - raise TypeError("Protocol version must be specified as a tuple") - - if protocol_version in handlers: - return {protocol_version: handlers[protocol_version]} - - return {} - @classmethod - def version_list(cls, versions): + def version_list(cls, versions, limit=4): """ Return a list of supported protocol versions in order of preference. @@ -359,7 +346,7 @@ def version_list(cls, versions): result[-1][1][1] = version[1] continue result.append(Version(version[0], [version[1], version[1]])) - if len(result) == 4: + if len(result) >= limit: break return result @@ -374,8 +361,11 @@ def get_handshake(cls): supported_versions = sorted( cls.protocol_handlers().keys(), reverse=True ) - offered_versions = cls.version_list(supported_versions) - versions_bytes = (v.to_bytes() for v in offered_versions) + offered_versions = cls.version_list(supported_versions, limit=3) + versions_bytes = ( + Version(0xFF, 1).to_bytes(), # handshake v2 + *(v.to_bytes() for v in offered_versions), + ) return b"".join(versions_bytes).ljust(16, b"\x00") @classmethod diff --git a/src/neo4j/_sync/io/_bolt_socket.py b/src/neo4j/_sync/io/_bolt_socket.py index 23cf0257b..506b9f18c 100644 --- a/src/neo4j/_sync/io/_bolt_socket.py +++ b/src/neo4j/_sync/io/_bolt_socket.py @@ -14,9 +14,13 @@ # limitations under the License. +from __future__ import annotations + import asyncio +import dataclasses import logging import struct +import typing as t from contextlib import suppress from ... import addressing @@ -34,54 +38,120 @@ ) +if t.TYPE_CHECKING: + from ..._deadline import Deadline + + log = logging.getLogger("neo4j.io") -class BoltSocket(BoltSocketBase): - def _handshake(self, resolved_address, deadline): - """ - Perform BOLT handshake. +@dataclasses.dataclass +class HandshakeCtx: + ctx: str + deadline: Deadline + local_port: int + resolved_address: addressing.ResolvedAddress + full_response: bytearray = dataclasses.field(default_factory=bytearray) - :param resolved_address: - :param deadline: Deadline for handshake - :returns: (version, client_handshake, server_response_data) - """ - local_port = self.getsockname()[1] +@dataclasses.dataclass +class BytesPrinter: + bytes: bytes | bytearray - # TODO: Optimize logging code - handshake = self.Bolt.get_handshake() - handshake = struct.unpack(">16B", handshake) - handshake = [handshake[i : i + 4] for i in range(0, len(handshake), 4)] + def __str__(self): + return f"0x{self.bytes.hex().upper()}" - supported_versions = [ - f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" - for vx in handshake - ] +class BoltSocket(BoltSocketBase): + def _parse_handshake_response_v1(self, ctx, response): + agreed_version = response[-1], response[-2] log.debug( - "[#%04X] C: 0x%08X", - local_port, - int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"), + "[#%04X] S: 0x%06X%02X", + ctx.local_port, + agreed_version[1], + agreed_version[0], + ) + return agreed_version + + def _parse_handshake_response_v2(self, ctx, response): + ctx.ctx = "handshake v2 offerings count" + num_offerings = self._read_varint(ctx) + offerings = [] + for i in range(num_offerings): + ctx.ctx = f"handshake v2 offering {i}" + offering_response = self._handshake_read(ctx, 4) + offering = offering_response[-1:-4:-1] + offerings.append(offering) + ctx.ctx = "handshake v2 capabilities" + _capabilities_offer = self._read_varint(ctx) + + if log.getEffectiveLevel() >= logging.DEBUG: + log.debug( + "[#%04X] S: %s [%i] %s %s", + ctx.local_port, + BytesPrinter(response), + num_offerings, + " ".join(f"0x{vx[1]:06X}{vx[0]:02X}" for vx in offerings), + BytesPrinter(self._encode_varint(_capabilities_offer)), + ) + + supported_versions = sorted(self.Bolt.protocol_handlers().keys()) + chosen_version = 0, 0 + for v in supported_versions: + for offer_major, offer_minor, offer_range in offerings: + offer_max = (offer_major, offer_minor) + offer_min = (offer_major, offer_minor - offer_range) + if offer_min <= v <= offer_max: + chosen_version = v + break + + ctx.ctx = "handshake v2 chosen version" + self._handshake_send( + ctx, bytes((0, 0, chosen_version[1], chosen_version[0])) ) + chosen_capabilities = 0 + capabilities = self._encode_varint(chosen_capabilities) + ctx.ctx = "handshake v2 chosen capabilities" log.debug( - "[#%04X] C: %s %s %s %s", - local_port, - *supported_versions, + "[#%04X] C: 0x%06X%02X %s", + ctx.local_port, + chosen_version[1], + chosen_version[0], + BytesPrinter(capabilities), ) + self._handshake_send(ctx, b"\x00") - request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() + return chosen_version + + def _read_varint(self, ctx): + next_byte = (self._handshake_read(ctx, 1))[0] + res = next_byte & 0x7F + i = 0 + while next_byte & 0x80: + i += 1 + next_byte = (self._handshake_read(ctx, 1))[0] + res += (next_byte & 0x7F) << (7 * i) + return res - # Handle the handshake response + @staticmethod + def _encode_varint(n): + res = bytearray() + while n >= 0x80: + res.append(n & 0x7F | 0x80) + n >>= 7 + res.append(n) + return res + + def _handshake_read(self, ctx, n): original_timeout = self.gettimeout() - self.settimeout(deadline.to_timeout()) + self.settimeout(ctx.deadline.to_timeout()) try: - self.sendall(request) - response = self.recv(4) + response = self.recv(n) + ctx.full_response.extend(response) except OSError as exc: raise ServiceUnavailable( - f"Failed to read any data from server {resolved_address!r} " - f"after connected (deadline {deadline})" + f"Failed to read {ctx.ctx} from server " + f"{ctx.resolved_address!r} (deadline {ctx.deadline})" ) from exc finally: self.settimeout(original_timeout) @@ -89,36 +159,113 @@ def _handshake(self, resolved_address, deadline): if data_size == 0: # If no data is returned after a successful select # response, the server has closed the connection - log.debug("[#%04X] S: ", local_port) + log.debug("[#%04X] S: ", ctx.local_port) self.close() raise ServiceUnavailable( - f"Connection to {resolved_address} closed without handshake " - "response" + f"Connection to {ctx.resolved_address} closed with incomplete " + f"handshake response" ) - if data_size != 4: + if data_size != n: # Some garbled data has been received - log.debug("[#%04X] S: @*#!", local_port) + log.debug("[#%04X] S: @*#!", ctx.local_port) self.close() raise BoltProtocolError( - "Expected four byte Bolt handshake response from " - f"{resolved_address!r}, received {response!r} instead; " + f"Expected {ctx.ctx} from {ctx.resolved_address!r}, received " + f"{response!r} instead (so far {ctx.full_response!r}); " "check for incorrect port number", - address=resolved_address, + address=ctx.resolved_address, + ) + + return response + + def _handshake_send(self, ctx, data): + original_timeout = self.gettimeout() + self.settimeout(ctx.deadline.to_timeout()) + try: + self.sendall(data) + except OSError as exc: + raise ServiceUnavailable( + f"Failed to write {ctx.ctx} to server " + f"{ctx.resolved_address!r} (deadline {ctx.deadline})" + ) from exc + finally: + self.settimeout(original_timeout) + + def _handshake(self, resolved_address, deadline): + """ + Perform BOLT handshake. + + :param resolved_address: + :param deadline: Deadline for handshake + + :returns: (version, client_handshake, server_response_data) + """ + local_port = self.getsockname()[1] + + if log.getEffectiveLevel() >= logging.DEBUG: + handshake = self.Bolt.get_handshake() + handshake = struct.unpack(">16B", handshake) + handshake = [ + handshake[i : i + 4] for i in range(0, len(handshake), 4) + ] + + supported_versions = [ + f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" + for vx in handshake + ] + + log.debug( + "[#%04X] C: 0x%08X", + local_port, + int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"), + ) + log.debug( + "[#%04X] C: %s %s %s %s", + local_port, + *supported_versions, ) - elif response == b"HTTP": + + request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() + + ctx = HandshakeCtx( + ctx="handshake opening", + deadline=deadline, + local_port=local_port, + resolved_address=resolved_address, + ) + + self._handshake_send(ctx, request) + + ctx.ctx = "four byte Bolt handshake response" + response = self._handshake_read(ctx, 4) + + if response == b"HTTP": log.debug("[#%04X] S: ", local_port) self.close() raise ServiceUnavailable( f"Cannot to connect to Bolt service on {resolved_address!r} " "(looks like HTTP)" ) - agreed_version = response[-1], response[-2] - log.debug( - "[#%04X] S: 0x%06X%02X", - local_port, - agreed_version[1], - agreed_version[0], - ) + elif response[-1] == 0xFF: + # manifest style handshake + manifest_version = response[-2] + if manifest_version == 0x01: + agreed_version = self._parse_handshake_response_v2( + ctx, + response, + ) + else: + raise BoltProtocolError( + "Unsupported Bolt handshake manifest version " + f"{manifest_version} received from {resolved_address!r}.", + address=resolved_address, + ) + else: + agreed_version = self._parse_handshake_response_v1( + ctx, + response, + ) + return agreed_version, handshake, response @classmethod diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index a34db62ef..5ddc47ba1 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -52,38 +52,31 @@ def test_class_method_protocol_handlers(): @pytest.mark.parametrize( ("test_input", "expected"), [ - ((0, 0), 0), - ((1, 0), 0), - ((2, 0), 0), - ((3, 0), 1), - ((4, 0), 0), - ((4, 1), 1), - ((4, 2), 1), - ((4, 3), 1), - ((4, 4), 1), - ((5, 0), 1), - ((5, 1), 1), - ((5, 2), 1), - ((5, 3), 1), - ((5, 4), 1), - ((5, 5), 1), - ((5, 6), 1), - ((5, 7), 0), - ((6, 0), 0), + ((0, 0), False), + ((1, 0), False), + ((2, 0), False), + ((3, 0), True), + ((4, 0), False), + ((4, 1), True), + ((4, 2), True), + ((4, 3), True), + ((4, 4), True), + ((5, 0), True), + ((5, 1), True), + ((5, 2), True), + ((5, 3), True), + ((5, 4), True), + ((5, 5), True), + ((5, 6), True), + ((5, 7), False), + ((6, 0), False), ], ) def test_class_method_protocol_handlers_with_protocol_version( test_input, expected ): - protocol_handlers = AsyncBolt.protocol_handlers( - protocol_version=test_input - ) - assert len(protocol_handlers) == expected - - -def test_class_method_protocol_handlers_with_invalid_protocol_version(): - with pytest.raises(TypeError): - AsyncBolt.protocol_handlers(protocol_version=2) + protocol_handlers = AsyncBolt.protocol_handlers() + assert (test_input in protocol_handlers) == expected # [bolt-version-bump] search tag when changing bolt version support @@ -91,7 +84,7 @@ def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( handshake - == b"\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x00\x01\xff\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x00\x03" ) diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index bab854f87..fdec34252 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -52,38 +52,31 @@ def test_class_method_protocol_handlers(): @pytest.mark.parametrize( ("test_input", "expected"), [ - ((0, 0), 0), - ((1, 0), 0), - ((2, 0), 0), - ((3, 0), 1), - ((4, 0), 0), - ((4, 1), 1), - ((4, 2), 1), - ((4, 3), 1), - ((4, 4), 1), - ((5, 0), 1), - ((5, 1), 1), - ((5, 2), 1), - ((5, 3), 1), - ((5, 4), 1), - ((5, 5), 1), - ((5, 6), 1), - ((5, 7), 0), - ((6, 0), 0), + ((0, 0), False), + ((1, 0), False), + ((2, 0), False), + ((3, 0), True), + ((4, 0), False), + ((4, 1), True), + ((4, 2), True), + ((4, 3), True), + ((4, 4), True), + ((5, 0), True), + ((5, 1), True), + ((5, 2), True), + ((5, 3), True), + ((5, 4), True), + ((5, 5), True), + ((5, 6), True), + ((5, 7), False), + ((6, 0), False), ], ) def test_class_method_protocol_handlers_with_protocol_version( test_input, expected ): - protocol_handlers = Bolt.protocol_handlers( - protocol_version=test_input - ) - assert len(protocol_handlers) == expected - - -def test_class_method_protocol_handlers_with_invalid_protocol_version(): - with pytest.raises(TypeError): - Bolt.protocol_handlers(protocol_version=2) + protocol_handlers = Bolt.protocol_handlers() + assert (test_input in protocol_handlers) == expected # [bolt-version-bump] search tag when changing bolt version support @@ -91,7 +84,7 @@ def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( handshake - == b"\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x00\x01\xff\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x00\x03" ) From cdf1a8e8e232ca92f631cad480d319f8d9b7f806 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 10 Oct 2024 18:10:34 +0200 Subject: [PATCH 03/18] Fix handshake v2 logging --- src/neo4j/_async/io/_bolt_socket.py | 4 +++- src/neo4j/_sync/io/_bolt_socket.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/neo4j/_async/io/_bolt_socket.py b/src/neo4j/_async/io/_bolt_socket.py index f1dc715d3..a5c0909a8 100644 --- a/src/neo4j/_async/io/_bolt_socket.py +++ b/src/neo4j/_async/io/_bolt_socket.py @@ -91,7 +91,9 @@ async def _parse_handshake_response_v2(self, ctx, response): ctx.local_port, BytesPrinter(response), num_offerings, - " ".join(f"0x{vx[1]:06X}{vx[0]:02X}" for vx in offerings), + " ".join( + f"0x{vx[2]:04X}{vx[1]:02X}{vx[0]:02X}" for vx in offerings + ), BytesPrinter(self._encode_varint(_capabilities_offer)), ) diff --git a/src/neo4j/_sync/io/_bolt_socket.py b/src/neo4j/_sync/io/_bolt_socket.py index 506b9f18c..774669a5d 100644 --- a/src/neo4j/_sync/io/_bolt_socket.py +++ b/src/neo4j/_sync/io/_bolt_socket.py @@ -91,7 +91,9 @@ def _parse_handshake_response_v2(self, ctx, response): ctx.local_port, BytesPrinter(response), num_offerings, - " ".join(f"0x{vx[1]:06X}{vx[0]:02X}" for vx in offerings), + " ".join( + f"0x{vx[2]:04X}{vx[1]:02X}{vx[0]:02X}" for vx in offerings + ), BytesPrinter(self._encode_varint(_capabilities_offer)), ) From 74863c45eae24bc723157ed28a18c859a90c2b82 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 10 Oct 2024 18:11:26 +0200 Subject: [PATCH 04/18] Fix unit tests --- src/neo4j/_async/io/_bolt.py | 10 +++++----- src/neo4j/_async/io/_bolt_socket.py | 5 ++++- src/neo4j/_sync/io/_bolt.py | 10 +++++----- src/neo4j/_sync/io/_bolt_socket.py | 5 ++++- tests/unit/async_/io/test_class_bolt.py | 7 ++++--- tests/unit/sync/io/test_class_bolt.py | 7 ++++--- 6 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 85aa5b439..a997393dc 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -478,12 +478,12 @@ async def open( elif protocol_version == (4, 2): from ._bolt4 import AsyncBolt4x2 bolt_cls = AsyncBolt4x2 - elif protocol_version == (4, 1): - from ._bolt4 import AsyncBolt4x1 - bolt_cls = AsyncBolt4x1 - # Implementation for 4.0 exists, but there was no space left in the + # Implementations for exist, but there was no space left in the # handshake to offer this version to the server. Hence, the server - # should never request us to speak bolt 4.0. + # should never request us to speak these bolt versions. + # elif protocol_version == (4, 1): + # from ._bolt4 import AsyncBolt4x1 + # bolt_cls = AsyncBolt4x1 # elif protocol_version == (4, 0): # from ._bolt4 import AsyncBolt4x0 # bolt_cls = AsyncBolt4x0 diff --git a/src/neo4j/_async/io/_bolt_socket.py b/src/neo4j/_async/io/_bolt_socket.py index a5c0909a8..6ca6eaa0c 100644 --- a/src/neo4j/_async/io/_bolt_socket.py +++ b/src/neo4j/_async/io/_bolt_socket.py @@ -308,7 +308,10 @@ async def connect( s = await cls._connect_secure( resolved_address, tcp_timeout, keep_alive, ssl_context ) - return (s, *await s._handshake(resolved_address, deadline)) + agreed_version, handshake, response = await s._handshake( + resolved_address, deadline + ) + return s, agreed_version, handshake, response except (BoltError, DriverError, OSError) as error: try: local_port = s.getsockname()[1] diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 126e2fd6d..e2b232d23 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -478,12 +478,12 @@ def open( elif protocol_version == (4, 2): from ._bolt4 import Bolt4x2 bolt_cls = Bolt4x2 - elif protocol_version == (4, 1): - from ._bolt4 import Bolt4x1 - bolt_cls = Bolt4x1 - # Implementation for 4.0 exists, but there was no space left in the + # Implementations for exist, but there was no space left in the # handshake to offer this version to the server. Hence, the server - # should never request us to speak bolt 4.0. + # should never request us to speak these bolt versions. + # elif protocol_version == (4, 1): + # from ._bolt4 import AsyncBolt4x1 + # bolt_cls = AsyncBolt4x1 # elif protocol_version == (4, 0): # from ._bolt4 import AsyncBolt4x0 # bolt_cls = AsyncBolt4x0 diff --git a/src/neo4j/_sync/io/_bolt_socket.py b/src/neo4j/_sync/io/_bolt_socket.py index 774669a5d..160b12e3c 100644 --- a/src/neo4j/_sync/io/_bolt_socket.py +++ b/src/neo4j/_sync/io/_bolt_socket.py @@ -308,7 +308,10 @@ def connect( s = cls._connect_secure( resolved_address, tcp_timeout, keep_alive, ssl_context ) - return (s, *s._handshake(resolved_address, deadline)) + agreed_version, handshake, response = s._handshake( + resolved_address, deadline + ) + return s, agreed_version, handshake, response except (BoltError, DriverError, OSError) as error: try: local_port = s.getsockname()[1] diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 5ddc47ba1..5091a24b5 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -37,7 +37,7 @@ def test_class_method_protocol_handlers(): # fmt: off expected_handlers = { (3, 0), - (4, 1), (4, 2), (4, 3), (4, 4), + (4, 2), (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), } # fmt: on @@ -57,7 +57,7 @@ def test_class_method_protocol_handlers(): ((2, 0), False), ((3, 0), True), ((4, 0), False), - ((4, 1), True), + ((4, 1), False), ((4, 2), True), ((4, 3), True), ((4, 4), True), @@ -171,6 +171,7 @@ async def test_version_negotiation( (0, 0), (2, 0), (4, 0), + (4, 1), (3, 1), (5, 7), (6, 0), @@ -179,7 +180,7 @@ async def test_version_negotiation( @mark_async_test async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( - "('3.0', '4.1', '4.2', '4.3', '4.4', " + "('3.0', '4.2', '4.3', '4.4', " "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" ) diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index fdec34252..73b91030d 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -37,7 +37,7 @@ def test_class_method_protocol_handlers(): # fmt: off expected_handlers = { (3, 0), - (4, 1), (4, 2), (4, 3), (4, 4), + (4, 2), (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), } # fmt: on @@ -57,7 +57,7 @@ def test_class_method_protocol_handlers(): ((2, 0), False), ((3, 0), True), ((4, 0), False), - ((4, 1), True), + ((4, 1), False), ((4, 2), True), ((4, 3), True), ((4, 4), True), @@ -171,6 +171,7 @@ def test_version_negotiation( (0, 0), (2, 0), (4, 0), + (4, 1), (3, 1), (5, 7), (6, 0), @@ -179,7 +180,7 @@ def test_version_negotiation( @mark_sync_test def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( - "('3.0', '4.1', '4.2', '4.3', '4.4', " + "('3.0', '4.2', '4.3', '4.4', " "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" ) From b23dca5de45804975f3980d28bce45d4dac311dd Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 10 Oct 2024 18:15:06 +0200 Subject: [PATCH 05/18] Fix socket resource leak --- .../_async_compat/network/_bolt_socket.py | 117 +++++++++--------- 1 file changed, 61 insertions(+), 56 deletions(-) diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py index ec14efa24..879b9344b 100644 --- a/src/neo4j/_async_compat/network/_bolt_socket.py +++ b/src/neo4j/_async_compat/network/_bolt_socket.py @@ -386,65 +386,70 @@ def _connect_secure( s = None # The socket try: - if len(resolved_address) == 2: - s = socket(AF_INET) - elif len(resolved_address) == 4: - s = socket(AF_INET6) - else: - raise ValueError(f"Unsupported address {resolved_address!r}") - t = s.gettimeout() - if timeout: - s.settimeout(timeout) - log.debug("[#0000] C: %s", resolved_address) - s.connect(resolved_address) - s.settimeout(t) - keep_alive = 1 if keep_alive else 0 - s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) - except SocketTimeout: - log.debug("[#0000] S: %s", resolved_address) - log.debug("[#0000] C: %s", resolved_address) - cls._kill_raw_socket(s) - raise ServiceUnavailable( - "Timed out trying to establish connection to " - f"{resolved_address!r}" - ) from None - except Exception as error: - log.debug( - "[#0000] S: %s %s", - type(error).__name__, - " ".join(map(repr, error.args)), - ) - log.debug("[#0000] C: %s", resolved_address) - cls._kill_raw_socket(s) - if isinstance(error, OSError): + try: + if len(resolved_address) == 2: + s = socket(AF_INET) + elif len(resolved_address) == 4: + s = socket(AF_INET6) + else: + raise ValueError( + f"Unsupported address {resolved_address!r}" + ) + t = s.gettimeout() + if timeout: + s.settimeout(timeout) + log.debug("[#0000] C: %s", resolved_address) + s.connect(resolved_address) + s.settimeout(t) + keep_alive = 1 if keep_alive else 0 + s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) + except SocketTimeout: + log.debug("[#0000] S: %s", resolved_address) raise ServiceUnavailable( - f"Failed to establish connection to {resolved_address!r} " - f"(reason {error})" - ) from error - raise + "Timed out trying to establish connection to " + f"{resolved_address!r}" + ) from None + except Exception as error: + log.debug( + "[#0000] S: %s %s", + type(error).__name__, + " ".join(map(repr, error.args)), + ) + if isinstance(error, OSError): + raise ServiceUnavailable( + "Failed to establish connection to " + f"{resolved_address!r} (reason {error})" + ) from error + raise - local_port = s.getsockname()[1] - # Secure the connection if an SSL context has been provided - if ssl_context: - hostname = resolved_address._host_name or None - sni_host = hostname if HAS_SNI and hostname else None - log.debug("[#%04X] C: %s", local_port, hostname) - try: - s = ssl_context.wrap_socket(s, server_hostname=sni_host) - except (OSError, SSLError, CertificateError) as cause: - cls._kill_raw_socket(s) - raise BoltSecurityError( - message="Failed to establish encrypted connection.", - address=(hostname, local_port), - ) from cause - # Check that the server provides a certificate - der_encoded_server_certificate = s.getpeercert(binary_form=True) - if der_encoded_server_certificate is None: - raise BoltProtocolError( - "When using an encrypted socket, the server should always " - "provide a certificate", - address=(hostname, local_port), + local_port = s.getsockname()[1] + # Secure the connection if an SSL context has been provided + if ssl_context: + hostname = resolved_address._host_name or None + sni_host = hostname if HAS_SNI and hostname else None + log.debug("[#%04X] C: %s", local_port, hostname) + try: + s = ssl_context.wrap_socket(s, server_hostname=sni_host) + except (OSError, SSLError, CertificateError) as cause: + raise BoltSecurityError( + message="Failed to establish encrypted connection.", + address=(hostname, local_port), + ) from cause + # Check that the server provides a certificate + der_encoded_server_certificate = s.getpeercert( + binary_form=True ) + if der_encoded_server_certificate is None: + raise BoltProtocolError( + "When using an encrypted socket, the server should" + "always provide a certificate", + address=(hostname, local_port), + ) + except Exception: + if s is not None: + log.debug("[#0000] C: %s", resolved_address) + cls._kill_raw_socket(s) + raise return cls(s) From 0b5c785faf17d7037766da02c1b4b708d148b5c9 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 11 Oct 2024 10:03:13 +0200 Subject: [PATCH 06/18] WIP --- src/neo4j/_async/io/__init__.py | 5 + src/neo4j/_async/io/_bolt.py | 168 +++++------------------- src/neo4j/_sync/io/__init__.py | 5 + src/neo4j/_sync/io/_bolt.py | 168 +++++------------------- src/neo4j/api.py | 8 +- testkit/_common.py | 2 +- testkitbackend/test_config.json | 1 - tests/unit/async_/io/test_class_bolt.py | 42 +++--- tests/unit/common/test_exceptions.py | 2 +- tests/unit/sync/io/test_class_bolt.py | 42 +++--- tox.ini | 8 +- 11 files changed, 124 insertions(+), 327 deletions(-) diff --git a/src/neo4j/_async/io/__init__.py b/src/neo4j/_async/io/__init__.py index 3571ad94e..b69f8377e 100644 --- a/src/neo4j/_async/io/__init__.py +++ b/src/neo4j/_async/io/__init__.py @@ -31,6 +31,11 @@ ] +from . import ( # noqa - imports needed to register protocol handlers + _bolt3, + _bolt4, + _bolt5, +) from ._bolt import AsyncBolt from ._common import ( check_supported_server_product, diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index a997393dc..3c5271467 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -60,6 +60,8 @@ if t.TYPE_CHECKING: + import typing_extensions as te + from ..._api import TelemetryAPI @@ -272,83 +274,31 @@ def assert_notification_filtering_support(self): f"{self.server_info.agent!r}" ) - # [bolt-version-bump] search tag when changing bolt version support - @classmethod - def protocol_handlers(cls): - """ - Return a dictionary of available Bolt protocol handlers. - - The handlers are keyed by version tuple. If an explicit protocol - version is provided, the dictionary will contain either zero or one - items, depending on whether that version is supported. If no protocol - version is provided, all available versions will be returned. - - :param protocol_version: tuple identifying a specific protocol - version (e.g. (3, 5)) or None - :returns: dictionary of version tuple to handler class for all - relevant and supported protocol versions - :raise TypeError: if protocol version is not passed in a tuple - """ - # Carry out Bolt subclass imports locally to avoid circular dependency - # issues. - from ._bolt3 import AsyncBolt3 - from ._bolt4 import ( - AsyncBolt4x2, - AsyncBolt4x3, - AsyncBolt4x4, - ) - from ._bolt5 import ( - AsyncBolt5x0, - AsyncBolt5x1, - AsyncBolt5x2, - AsyncBolt5x3, - AsyncBolt5x4, - AsyncBolt5x5, - AsyncBolt5x6, - ) - - return { - AsyncBolt3.PROTOCOL_VERSION: AsyncBolt3, - # 4.0-4.1 unsupported because no space left in the handshake - AsyncBolt4x2.PROTOCOL_VERSION: AsyncBolt4x2, - AsyncBolt4x3.PROTOCOL_VERSION: AsyncBolt4x3, - AsyncBolt4x4.PROTOCOL_VERSION: AsyncBolt4x4, - AsyncBolt5x0.PROTOCOL_VERSION: AsyncBolt5x0, - AsyncBolt5x1.PROTOCOL_VERSION: AsyncBolt5x1, - AsyncBolt5x2.PROTOCOL_VERSION: AsyncBolt5x2, - AsyncBolt5x3.PROTOCOL_VERSION: AsyncBolt5x3, - AsyncBolt5x4.PROTOCOL_VERSION: AsyncBolt5x4, - AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5, - AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, - } - - @classmethod - def version_list(cls, versions, limit=4): - """ - Return a list of supported protocol versions in order of preference. + protocol_handlers: t.ClassVar[dict[Version, type[AsyncBolt]]] = {} - The number of protocol versions (or ranges) returned is limited to 4. - """ - # In fact, 4.3 is the fist version to support ranges. However, the - # range support got backported to 4.2. But even if the server is too - # old to have the backport, negotiating BOLT 4.1 is no problem as it's - # equivalent to 4.2 - first_with_range_support = Version(4, 2) - result = [] - for version in versions: - if ( - result - and version >= first_with_range_support - and result[-1][0] == version[0] - and result[-1][1][1] == version[1] + 1 - ): - # can use range to encompass this version - result[-1][1][1] = version[1] - continue - result.append(Version(version[0], [version[1], version[1]])) - if len(result) >= limit: - break - return result + def __init_subclass__(cls: type[te.Self], **kwargs: t.Any) -> None: + protocol_version = cls.PROTOCOL_VERSION + if protocol_version is None: + raise ValueError( + "AsyncBolt subclasses must define PROTOCOL_VERSION" + ) + if not ( + isinstance(protocol_version, Version) + and len(protocol_version) == 2 + and all(isinstance(i, int) for i in protocol_version) + ): + raise TypeError( + "PROTOCOL_VERSION must be a 2-tuple of integers, not " + f"{protocol_version!r}" + ) + if protocol_version in AsyncBolt.protocol_handlers: + cls_conflict = AsyncBolt.protocol_handlers[protocol_version] + raise TypeError( + f"Multiple classes for the same protocol version " + f"{protocol_version}: {cls}, {cls_conflict}" + ) + cls.protocol_handlers[protocol_version] = cls + super().__init_subclass__(**kwargs) @classmethod def get_handshake(cls): @@ -358,15 +308,9 @@ def get_handshake(cls): The length is 16 bytes as specified in the Bolt version negotiation. :returns: bytes """ - supported_versions = sorted( - cls.protocol_handlers().keys(), reverse=True - ) - offered_versions = cls.version_list(supported_versions, limit=3) - versions_bytes = ( - Version(0xFF, 1).to_bytes(), # handshake v2 - *(v.to_bytes() for v in offered_versions), + return ( + b"\x00\x00\x01\xff\x00\x06\x06\x05\x00\x04\x04\x04\x00\x00\x00\x03" ) - return b"".join(versions_bytes).ljust(16, b"\x00") @classmethod async def ping(cls, address, *, deadline=None, pool_config=None): @@ -442,64 +386,16 @@ async def open( ) pool_config.protocol_version = protocol_version - - # Carry out Bolt subclass imports locally to avoid circular dependency - # issues. - - # avoid new lines after imports for better readability and conciseness - # fmt: off - if protocol_version == (5, 6): - from ._bolt5 import AsyncBolt5x6 - bolt_cls = AsyncBolt5x6 - elif protocol_version == (5, 5): - from ._bolt5 import AsyncBolt5x5 - bolt_cls = AsyncBolt5x5 - elif protocol_version == (5, 4): - from ._bolt5 import AsyncBolt5x4 - bolt_cls = AsyncBolt5x4 - elif protocol_version == (5, 3): - from ._bolt5 import AsyncBolt5x3 - bolt_cls = AsyncBolt5x3 - elif protocol_version == (5, 2): - from ._bolt5 import AsyncBolt5x2 - bolt_cls = AsyncBolt5x2 - elif protocol_version == (5, 1): - from ._bolt5 import AsyncBolt5x1 - bolt_cls = AsyncBolt5x1 - elif protocol_version == (5, 0): - from ._bolt5 import AsyncBolt5x0 - bolt_cls = AsyncBolt5x0 - elif protocol_version == (4, 4): - from ._bolt4 import AsyncBolt4x4 - bolt_cls = AsyncBolt4x4 - elif protocol_version == (4, 3): - from ._bolt4 import AsyncBolt4x3 - bolt_cls = AsyncBolt4x3 - elif protocol_version == (4, 2): - from ._bolt4 import AsyncBolt4x2 - bolt_cls = AsyncBolt4x2 - # Implementations for exist, but there was no space left in the - # handshake to offer this version to the server. Hence, the server - # should never request us to speak these bolt versions. - # elif protocol_version == (4, 1): - # from ._bolt4 import AsyncBolt4x1 - # bolt_cls = AsyncBolt4x1 - # elif protocol_version == (4, 0): - # from ._bolt4 import AsyncBolt4x0 - # bolt_cls = AsyncBolt4x0 - elif protocol_version == (3, 0): - from ._bolt3 import AsyncBolt3 - bolt_cls = AsyncBolt3 - # fmt: on - else: + protocol_handlers = AsyncBolt.protocol_handlers + bolt_cls = protocol_handlers.get(protocol_version) + if bolt_cls is None: log.debug("[#%04X] C: ", s.getsockname()[1]) await AsyncBoltSocket.close_socket(s) - supported_versions = cls.protocol_handlers().keys() raise BoltHandshakeError( "The neo4j server does not support communication with this " "driver. This driver has support for Bolt protocols " - f"{tuple(map(str, supported_versions))}.", + f"{tuple(map(str, AsyncBolt.protocol_handlers.keys()))}.", address=address, request_data=handshake, response_data=data, diff --git a/src/neo4j/_sync/io/__init__.py b/src/neo4j/_sync/io/__init__.py index a1833c743..4f9009d88 100644 --- a/src/neo4j/_sync/io/__init__.py +++ b/src/neo4j/_sync/io/__init__.py @@ -31,6 +31,11 @@ ] +from . import ( # noqa - imports needed to register protocol handlers + _bolt3, + _bolt4, + _bolt5, +) from ._bolt import Bolt from ._common import ( check_supported_server_product, diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index e2b232d23..46ac83a57 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -60,6 +60,8 @@ if t.TYPE_CHECKING: + import typing_extensions as te + from ..._api import TelemetryAPI @@ -272,83 +274,31 @@ def assert_notification_filtering_support(self): f"{self.server_info.agent!r}" ) - # [bolt-version-bump] search tag when changing bolt version support - @classmethod - def protocol_handlers(cls): - """ - Return a dictionary of available Bolt protocol handlers. - - The handlers are keyed by version tuple. If an explicit protocol - version is provided, the dictionary will contain either zero or one - items, depending on whether that version is supported. If no protocol - version is provided, all available versions will be returned. - - :param protocol_version: tuple identifying a specific protocol - version (e.g. (3, 5)) or None - :returns: dictionary of version tuple to handler class for all - relevant and supported protocol versions - :raise TypeError: if protocol version is not passed in a tuple - """ - # Carry out Bolt subclass imports locally to avoid circular dependency - # issues. - from ._bolt3 import Bolt3 - from ._bolt4 import ( - Bolt4x2, - Bolt4x3, - Bolt4x4, - ) - from ._bolt5 import ( - Bolt5x0, - Bolt5x1, - Bolt5x2, - Bolt5x3, - Bolt5x4, - Bolt5x5, - Bolt5x6, - ) - - return { - Bolt3.PROTOCOL_VERSION: Bolt3, - # 4.0-4.1 unsupported because no space left in the handshake - Bolt4x2.PROTOCOL_VERSION: Bolt4x2, - Bolt4x3.PROTOCOL_VERSION: Bolt4x3, - Bolt4x4.PROTOCOL_VERSION: Bolt4x4, - Bolt5x0.PROTOCOL_VERSION: Bolt5x0, - Bolt5x1.PROTOCOL_VERSION: Bolt5x1, - Bolt5x2.PROTOCOL_VERSION: Bolt5x2, - Bolt5x3.PROTOCOL_VERSION: Bolt5x3, - Bolt5x4.PROTOCOL_VERSION: Bolt5x4, - Bolt5x5.PROTOCOL_VERSION: Bolt5x5, - Bolt5x6.PROTOCOL_VERSION: Bolt5x6, - } - - @classmethod - def version_list(cls, versions, limit=4): - """ - Return a list of supported protocol versions in order of preference. + protocol_handlers: t.ClassVar[dict[Version, type[Bolt]]] = {} - The number of protocol versions (or ranges) returned is limited to 4. - """ - # In fact, 4.3 is the fist version to support ranges. However, the - # range support got backported to 4.2. But even if the server is too - # old to have the backport, negotiating BOLT 4.1 is no problem as it's - # equivalent to 4.2 - first_with_range_support = Version(4, 2) - result = [] - for version in versions: - if ( - result - and version >= first_with_range_support - and result[-1][0] == version[0] - and result[-1][1][1] == version[1] + 1 - ): - # can use range to encompass this version - result[-1][1][1] = version[1] - continue - result.append(Version(version[0], [version[1], version[1]])) - if len(result) >= limit: - break - return result + def __init_subclass__(cls: type[te.Self], **kwargs: t.Any) -> None: + protocol_version = cls.PROTOCOL_VERSION + if protocol_version is None: + raise ValueError( + "Bolt subclasses must define PROTOCOL_VERSION" + ) + if not ( + isinstance(protocol_version, Version) + and len(protocol_version) == 2 + and all(isinstance(i, int) for i in protocol_version) + ): + raise TypeError( + "PROTOCOL_VERSION must be a 2-tuple of integers, not " + f"{protocol_version!r}" + ) + if protocol_version in Bolt.protocol_handlers: + cls_conflict = Bolt.protocol_handlers[protocol_version] + raise TypeError( + f"Multiple classes for the same protocol version " + f"{protocol_version}: {cls}, {cls_conflict}" + ) + cls.protocol_handlers[protocol_version] = cls + super().__init_subclass__(**kwargs) @classmethod def get_handshake(cls): @@ -358,15 +308,9 @@ def get_handshake(cls): The length is 16 bytes as specified in the Bolt version negotiation. :returns: bytes """ - supported_versions = sorted( - cls.protocol_handlers().keys(), reverse=True - ) - offered_versions = cls.version_list(supported_versions, limit=3) - versions_bytes = ( - Version(0xFF, 1).to_bytes(), # handshake v2 - *(v.to_bytes() for v in offered_versions), + return ( + b"\x00\x00\x01\xff\x00\x06\x06\x05\x00\x04\x04\x04\x00\x00\x00\x03" ) - return b"".join(versions_bytes).ljust(16, b"\x00") @classmethod def ping(cls, address, *, deadline=None, pool_config=None): @@ -442,64 +386,16 @@ def open( ) pool_config.protocol_version = protocol_version - - # Carry out Bolt subclass imports locally to avoid circular dependency - # issues. - - # avoid new lines after imports for better readability and conciseness - # fmt: off - if protocol_version == (5, 6): - from ._bolt5 import Bolt5x6 - bolt_cls = Bolt5x6 - elif protocol_version == (5, 5): - from ._bolt5 import Bolt5x5 - bolt_cls = Bolt5x5 - elif protocol_version == (5, 4): - from ._bolt5 import Bolt5x4 - bolt_cls = Bolt5x4 - elif protocol_version == (5, 3): - from ._bolt5 import Bolt5x3 - bolt_cls = Bolt5x3 - elif protocol_version == (5, 2): - from ._bolt5 import Bolt5x2 - bolt_cls = Bolt5x2 - elif protocol_version == (5, 1): - from ._bolt5 import Bolt5x1 - bolt_cls = Bolt5x1 - elif protocol_version == (5, 0): - from ._bolt5 import Bolt5x0 - bolt_cls = Bolt5x0 - elif protocol_version == (4, 4): - from ._bolt4 import Bolt4x4 - bolt_cls = Bolt4x4 - elif protocol_version == (4, 3): - from ._bolt4 import Bolt4x3 - bolt_cls = Bolt4x3 - elif protocol_version == (4, 2): - from ._bolt4 import Bolt4x2 - bolt_cls = Bolt4x2 - # Implementations for exist, but there was no space left in the - # handshake to offer this version to the server. Hence, the server - # should never request us to speak these bolt versions. - # elif protocol_version == (4, 1): - # from ._bolt4 import AsyncBolt4x1 - # bolt_cls = AsyncBolt4x1 - # elif protocol_version == (4, 0): - # from ._bolt4 import AsyncBolt4x0 - # bolt_cls = AsyncBolt4x0 - elif protocol_version == (3, 0): - from ._bolt3 import Bolt3 - bolt_cls = Bolt3 - # fmt: on - else: + protocol_handlers = Bolt.protocol_handlers + bolt_cls = protocol_handlers.get(protocol_version) + if bolt_cls is None: log.debug("[#%04X] C: ", s.getsockname()[1]) BoltSocket.close_socket(s) - supported_versions = cls.protocol_handlers().keys() raise BoltHandshakeError( "The neo4j server does not support communication with this " "driver. This driver has support for Bolt protocols " - f"{tuple(map(str, supported_versions))}.", + f"{tuple(map(str, Bolt.protocol_handlers.keys()))}.", address=address, request_data=handshake, response_data=data, diff --git a/src/neo4j/api.py b/src/neo4j/api.py index cbf0c4189..1a4199c24 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -409,7 +409,13 @@ def update(self, metadata: dict) -> None: # TODO: 6.0 - this class should not be public. # As far the user is concerned, protocol versions should simply be a # tuple[int, int]. -class Version(tuple): +if t.TYPE_CHECKING: + _version_base = t.Tuple[int, int] +else: + _version_base = tuple + + +class Version(_version_base): def __new__(cls, *v): return super().__new__(cls, v) diff --git a/testkit/_common.py b/testkit/_common.py index 7e3e6a346..1c39b36de 100644 --- a/testkit/_common.py +++ b/testkit/_common.py @@ -47,6 +47,6 @@ def get_python_version(): def run_python(args, env=None, warning_as_error=True): cmd = [TEST_BACKEND_VERSION, "-u"] if warning_as_error: - cmd += ["-W", "error"] + cmd += ["-W", "error", "-X", "tracemalloc=10"] cmd += list(args) run(cmd, env=env) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index e4d8b14b5..779ad2ce1 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -47,7 +47,6 @@ "Feature:Auth:Kerberos": true, "Feature:Auth:Managed": true, "Feature:Bolt:3.0": true, - "Feature:Bolt:4.1": true, "Feature:Bolt:4.2": true, "Feature:Bolt:4.3": true, "Feature:Bolt:4.4": true, diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 5091a24b5..0a6b7fb5d 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -29,20 +29,17 @@ ) -# python -m pytest tests/unit/io/test_class_bolt.py -s -v - - # [bolt-version-bump] search tag when changing bolt version support def test_class_method_protocol_handlers(): # fmt: off expected_handlers = { (3, 0), - (4, 2), (4, 3), (4, 4), + (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), } # fmt: on - protocol_handlers = AsyncBolt.protocol_handlers() + protocol_handlers = AsyncBolt.protocol_handlers assert len(protocol_handlers) == len(expected_handlers) assert protocol_handlers.keys() == expected_handlers @@ -56,11 +53,13 @@ def test_class_method_protocol_handlers(): ((1, 0), False), ((2, 0), False), ((3, 0), True), - ((4, 0), False), - ((4, 1), False), + ((3, 1), False), + ((4, 0), True), + ((4, 1), True), ((4, 2), True), ((4, 3), True), ((4, 4), True), + ((4, 5), False), ((5, 0), True), ((5, 1), True), ((5, 2), True), @@ -75,7 +74,7 @@ def test_class_method_protocol_handlers(): def test_class_method_protocol_handlers_with_protocol_version( test_input, expected ): - protocol_handlers = AsyncBolt.protocol_handlers() + protocol_handlers = AsyncBolt.protocol_handlers assert (test_input in protocol_handlers) == expected @@ -84,7 +83,7 @@ def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( handshake - == b"\x00\x00\x01\xff\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x00\x03" + == b"\x00\x00\x01\xff\x00\x06\x06\x05\x00\x04\x04\x04\x00\x00\x00\x03" ) @@ -111,6 +110,10 @@ async def test_cancel_hello_in_open(mocker, none_auth): bolt_mock.socket = socket_mock bolt_mock.hello.side_effect = asyncio.CancelledError() bolt_mock.local_port = 1234 + mocker.patch.dict( + AsyncBolt.protocol_handlers, + {(5, 0): bolt_cls_mock}, + ) with pytest.raises(asyncio.CancelledError): await AsyncBolt.open(address, auth_manager=none_auth) @@ -123,6 +126,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ("bolt_version", "bolt_cls_path"), ( ((3, 0), "neo4j._async.io._bolt3.AsyncBolt3"), + ((4, 0), "neo4j._async.io._bolt4.AsyncBolt4x0"), ((4, 1), "neo4j._async.io._bolt4.AsyncBolt4x1"), ((4, 2), "neo4j._async.io._bolt4.AsyncBolt4x2"), ((4, 3), "neo4j._async.io._bolt4.AsyncBolt4x3"), @@ -157,6 +161,10 @@ async def test_version_negotiation( bolt_cls_mock.return_value.local_port = 1234 bolt_mock = bolt_cls_mock.return_value bolt_mock.socket = socket_mock + mocker.patch.dict( + AsyncBolt.protocol_handlers, + {bolt_version: bolt_cls_mock}, + ) connection = await AsyncBolt.open(address, auth_manager=none_auth) @@ -170,8 +178,6 @@ async def test_version_negotiation( ( (0, 0), (2, 0), - (4, 0), - (4, 1), (3, 1), (5, 7), (6, 0), @@ -180,7 +186,7 @@ async def test_version_negotiation( @mark_async_test async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( - "('3.0', '4.2', '4.3', '4.4', " + "('3.0', '4.0', '4.1', '4.2', '4.3', '4.4', " "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" ) @@ -214,12 +220,6 @@ async def test_cancel_manager_in_open(mocker): ) socket_cls_mock.connect.return_value = (socket_mock, (5, 0), None, None) socket_mock.getpeername.return_value = address - bolt_cls_mock = mocker.patch( - "neo4j._async.io._bolt5.AsyncBolt5x0", autospec=True - ) - bolt_mock = bolt_cls_mock.return_value - bolt_mock.socket = socket_mock - bolt_mock.local_port = 1234 auth_manager = mocker.AsyncMock( spec=neo4j.auth_management.AsyncAuthManager @@ -242,12 +242,6 @@ async def test_fail_manager_in_open(mocker): ) socket_cls_mock.connect.return_value = (socket_mock, (5, 0), None, None) socket_mock.getpeername.return_value = address - bolt_cls_mock = mocker.patch( - "neo4j._async.io._bolt5.AsyncBolt5x0", autospec=True - ) - bolt_mock = bolt_cls_mock.return_value - bolt_mock.socket = socket_mock - bolt_mock.local_port = 1234 auth_manager = mocker.AsyncMock( spec=neo4j.auth_management.AsyncAuthManager diff --git a/tests/unit/common/test_exceptions.py b/tests/unit/common/test_exceptions.py index 8cc192cdb..64cbb6ffb 100644 --- a/tests/unit/common/test_exceptions.py +++ b/tests/unit/common/test_exceptions.py @@ -73,7 +73,7 @@ def test_bolt_handshake_error(): b"\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00" ) response = b"\x00\x00\x00\x00" - supported_versions = Bolt.protocol_handlers().keys() + supported_versions = Bolt.protocol_handlers.keys() with pytest.raises(BoltHandshakeError) as e: error = BoltHandshakeError( diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index 73b91030d..31186a786 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -29,20 +29,17 @@ ) -# python -m pytest tests/unit/io/test_class_bolt.py -s -v - - # [bolt-version-bump] search tag when changing bolt version support def test_class_method_protocol_handlers(): # fmt: off expected_handlers = { (3, 0), - (4, 2), (4, 3), (4, 4), + (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), } # fmt: on - protocol_handlers = Bolt.protocol_handlers() + protocol_handlers = Bolt.protocol_handlers assert len(protocol_handlers) == len(expected_handlers) assert protocol_handlers.keys() == expected_handlers @@ -56,11 +53,13 @@ def test_class_method_protocol_handlers(): ((1, 0), False), ((2, 0), False), ((3, 0), True), - ((4, 0), False), - ((4, 1), False), + ((3, 1), False), + ((4, 0), True), + ((4, 1), True), ((4, 2), True), ((4, 3), True), ((4, 4), True), + ((4, 5), False), ((5, 0), True), ((5, 1), True), ((5, 2), True), @@ -75,7 +74,7 @@ def test_class_method_protocol_handlers(): def test_class_method_protocol_handlers_with_protocol_version( test_input, expected ): - protocol_handlers = Bolt.protocol_handlers() + protocol_handlers = Bolt.protocol_handlers assert (test_input in protocol_handlers) == expected @@ -84,7 +83,7 @@ def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( handshake - == b"\x00\x00\x01\xff\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x00\x03" + == b"\x00\x00\x01\xff\x00\x06\x06\x05\x00\x04\x04\x04\x00\x00\x00\x03" ) @@ -111,6 +110,10 @@ def test_cancel_hello_in_open(mocker, none_auth): bolt_mock.socket = socket_mock bolt_mock.hello.side_effect = asyncio.CancelledError() bolt_mock.local_port = 1234 + mocker.patch.dict( + Bolt.protocol_handlers, + {(5, 0): bolt_cls_mock}, + ) with pytest.raises(asyncio.CancelledError): Bolt.open(address, auth_manager=none_auth) @@ -123,6 +126,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ("bolt_version", "bolt_cls_path"), ( ((3, 0), "neo4j._sync.io._bolt3.Bolt3"), + ((4, 0), "neo4j._sync.io._bolt4.Bolt4x0"), ((4, 1), "neo4j._sync.io._bolt4.Bolt4x1"), ((4, 2), "neo4j._sync.io._bolt4.Bolt4x2"), ((4, 3), "neo4j._sync.io._bolt4.Bolt4x3"), @@ -157,6 +161,10 @@ def test_version_negotiation( bolt_cls_mock.return_value.local_port = 1234 bolt_mock = bolt_cls_mock.return_value bolt_mock.socket = socket_mock + mocker.patch.dict( + Bolt.protocol_handlers, + {bolt_version: bolt_cls_mock}, + ) connection = Bolt.open(address, auth_manager=none_auth) @@ -170,8 +178,6 @@ def test_version_negotiation( ( (0, 0), (2, 0), - (4, 0), - (4, 1), (3, 1), (5, 7), (6, 0), @@ -180,7 +186,7 @@ def test_version_negotiation( @mark_sync_test def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( - "('3.0', '4.2', '4.3', '4.4', " + "('3.0', '4.0', '4.1', '4.2', '4.3', '4.4', " "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" ) @@ -214,12 +220,6 @@ def test_cancel_manager_in_open(mocker): ) socket_cls_mock.connect.return_value = (socket_mock, (5, 0), None, None) socket_mock.getpeername.return_value = address - bolt_cls_mock = mocker.patch( - "neo4j._sync.io._bolt5.Bolt5x0", autospec=True - ) - bolt_mock = bolt_cls_mock.return_value - bolt_mock.socket = socket_mock - bolt_mock.local_port = 1234 auth_manager = mocker.MagicMock( spec=neo4j.auth_management.AuthManager @@ -242,12 +242,6 @@ def test_fail_manager_in_open(mocker): ) socket_cls_mock.connect.return_value = (socket_mock, (5, 0), None, None) socket_mock.getpeername.return_value = address - bolt_cls_mock = mocker.patch( - "neo4j._sync.io._bolt5.Bolt5x0", autospec=True - ) - bolt_mock = bolt_cls_mock.return_value - bolt_mock.socket = socket_mock - bolt_mock.local_port = 1234 auth_manager = mocker.MagicMock( spec=neo4j.auth_management.AuthManager diff --git a/tox.ini b/tox.ini index d75274c2e..99eaab19f 100644 --- a/tox.ini +++ b/tox.ini @@ -4,13 +4,15 @@ envlist = py{37,38,39,310,311,312}-{unit,integration,performance} [testenv] passenv = TEST_NEO4J_* deps = -r requirements-dev.txt -setenv = COVERAGE_FILE={envdir}/.coverage +setenv = + COVERAGE_FILE={envdir}/.coverage + unit,performance,integration: PYTHONTRACEMALLOC = 10 usedevelop = true warnargs = - py{37,38,39,310,311,312}: -W error + py{37,38,39,310,311,312}: -W error -W ignore::pytest.PytestUnraisableExceptionWarning commands = coverage erase unit: coverage run -m pytest {[testenv]warnargs} -v {posargs} tests/unit integration: coverage run -m pytest {[testenv]warnargs} -v {posargs} tests/integration performance: python -m pytest --benchmark-autosave -v {posargs} tests/performance - unit,integration: coverage report +; unit,integration: coverage report From cb5dc26cac78cfc3321216cae988a98ccae49a52 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 16 Oct 2024 13:23:14 +0200 Subject: [PATCH 07/18] WIP --- testkitbackend/test_config.json | 1 - 1 file changed, 1 deletion(-) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index dc720619c..1b20f3583 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -47,7 +47,6 @@ "Feature:Auth:Kerberos": true, "Feature:Auth:Managed": true, "Feature:Bolt:3.0": true, - "Feature:Bolt:4.2": true, "Feature:Bolt:4.3": true, "Feature:Bolt:4.4": true, "Feature:Bolt:5.0": true, From 380c70cad47bfd827a1254b92958c510f44f138d Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 16 Oct 2024 13:23:14 +0200 Subject: [PATCH 08/18] WIP --- src/neo4j/_async/io/_bolt_socket.py | 8 ++++---- src/neo4j/_sync/io/_bolt_socket.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/neo4j/_async/io/_bolt_socket.py b/src/neo4j/_async/io/_bolt_socket.py index 6ca6eaa0c..843aaf58b 100644 --- a/src/neo4j/_async/io/_bolt_socket.py +++ b/src/neo4j/_async/io/_bolt_socket.py @@ -66,7 +66,7 @@ class AsyncBoltSocket(AsyncBoltSocketBase): async def _parse_handshake_response_v1(self, ctx, response): agreed_version = response[-1], response[-2] log.debug( - "[#%04X] S: 0x%06X%02X", + "[#%04X] S: 0x%06X%02X", ctx.local_port, agreed_version[1], agreed_version[0], @@ -87,7 +87,7 @@ async def _parse_handshake_response_v2(self, ctx, response): if log.getEffectiveLevel() >= logging.DEBUG: log.debug( - "[#%04X] S: %s [%i] %s %s", + "[#%04X] S: %s [%i] %s %s", ctx.local_port, BytesPrinter(response), num_offerings, @@ -97,7 +97,7 @@ async def _parse_handshake_response_v2(self, ctx, response): BytesPrinter(self._encode_varint(_capabilities_offer)), ) - supported_versions = sorted(self.Bolt.protocol_handlers().keys()) + supported_versions = sorted(self.Bolt.protocol_handlers.keys()) chosen_version = 0, 0 for v in supported_versions: for offer_major, offer_minor, offer_range in offerings: @@ -115,7 +115,7 @@ async def _parse_handshake_response_v2(self, ctx, response): capabilities = self._encode_varint(chosen_capabilities) ctx.ctx = "handshake v2 chosen capabilities" log.debug( - "[#%04X] C: 0x%06X%02X %s", + "[#%04X] C: 0x%06X%02X %s", ctx.local_port, chosen_version[1], chosen_version[0], diff --git a/src/neo4j/_sync/io/_bolt_socket.py b/src/neo4j/_sync/io/_bolt_socket.py index 160b12e3c..223f93397 100644 --- a/src/neo4j/_sync/io/_bolt_socket.py +++ b/src/neo4j/_sync/io/_bolt_socket.py @@ -66,7 +66,7 @@ class BoltSocket(BoltSocketBase): def _parse_handshake_response_v1(self, ctx, response): agreed_version = response[-1], response[-2] log.debug( - "[#%04X] S: 0x%06X%02X", + "[#%04X] S: 0x%06X%02X", ctx.local_port, agreed_version[1], agreed_version[0], @@ -87,7 +87,7 @@ def _parse_handshake_response_v2(self, ctx, response): if log.getEffectiveLevel() >= logging.DEBUG: log.debug( - "[#%04X] S: %s [%i] %s %s", + "[#%04X] S: %s [%i] %s %s", ctx.local_port, BytesPrinter(response), num_offerings, @@ -97,7 +97,7 @@ def _parse_handshake_response_v2(self, ctx, response): BytesPrinter(self._encode_varint(_capabilities_offer)), ) - supported_versions = sorted(self.Bolt.protocol_handlers().keys()) + supported_versions = sorted(self.Bolt.protocol_handlers.keys()) chosen_version = 0, 0 for v in supported_versions: for offer_major, offer_minor, offer_range in offerings: @@ -115,7 +115,7 @@ def _parse_handshake_response_v2(self, ctx, response): capabilities = self._encode_varint(chosen_capabilities) ctx.ctx = "handshake v2 chosen capabilities" log.debug( - "[#%04X] C: 0x%06X%02X %s", + "[#%04X] C: 0x%06X%02X %s", ctx.local_port, chosen_version[1], chosen_version[0], From 98a3c2e6b5bc906a9c668d5ca8ec5c5b4005a86f Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 23 Oct 2024 17:07:00 +0200 Subject: [PATCH 09/18] TestKit + tox config improvements --- requirements-dev.txt | 1 + src/neo4j/_async/io/__init__.py | 1 + src/neo4j/_sync/io/__init__.py | 1 + testkit/testkit.json | 2 +- testkit/unittests.py | 2 +- testkitbackend/test_config.json | 1 + tox.ini | 9 ++++++--- 7 files changed, 12 insertions(+), 5 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index ae3f9270f..fe3e3fe47 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -25,6 +25,7 @@ pytest-cov>=3.0.0 pytest-mock>=3.6.1 teamcity-messages>=1.29 tox>=4.0.0 +teamcity-messages>=1.32 # needed for building docs sphinx diff --git a/src/neo4j/_async/io/__init__.py b/src/neo4j/_async/io/__init__.py index b69f8377e..5797e242d 100644 --- a/src/neo4j/_async/io/__init__.py +++ b/src/neo4j/_async/io/__init__.py @@ -31,6 +31,7 @@ ] +# [bolt-version-bump] search tag when changing bolt version support from . import ( # noqa - imports needed to register protocol handlers _bolt3, _bolt4, diff --git a/src/neo4j/_sync/io/__init__.py b/src/neo4j/_sync/io/__init__.py index 4f9009d88..329b4b4cf 100644 --- a/src/neo4j/_sync/io/__init__.py +++ b/src/neo4j/_sync/io/__init__.py @@ -31,6 +31,7 @@ ] +# [bolt-version-bump] search tag when changing bolt version support from . import ( # noqa - imports needed to register protocol handlers _bolt3, _bolt4, diff --git a/testkit/testkit.json b/testkit/testkit.json index 931900356..9e496140c 100644 --- a/testkit/testkit.json +++ b/testkit/testkit.json @@ -1,6 +1,6 @@ { "testkit": { "uri": "https://github.com/neo4j-drivers/testkit.git", - "ref": "5.0" + "ref": "bolt-handshake-v2" } } diff --git a/testkit/unittests.py b/testkit/unittests.py index 4c8dbed4d..1652cd4ef 100644 --- a/testkit/unittests.py +++ b/testkit/unittests.py @@ -18,4 +18,4 @@ if __name__ == "__main__": - run_python(["-m", "tox", "-vv", "-f", "unit"]) + run_python(["-m", "tox", "-vv", "-p", "-f", "unit"]) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 1b20f3583..dc720619c 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -47,6 +47,7 @@ "Feature:Auth:Kerberos": true, "Feature:Auth:Managed": true, "Feature:Bolt:3.0": true, + "Feature:Bolt:4.2": true, "Feature:Bolt:4.3": true, "Feature:Bolt:4.4": true, "Feature:Bolt:5.0": true, diff --git a/tox.ini b/tox.ini index 1bc8ddbec..17568d624 100644 --- a/tox.ini +++ b/tox.ini @@ -4,17 +4,20 @@ envlist = py{37,38,39,310,311,312,313}-{unit,integration,performance} requires = virtualenv<20.22.0 [testenv] -passenv = TEST_NEO4J_* +passenv = + TEST_NEO4J_* + TEAMCITY_VERSION deps = -r requirements-dev.txt setenv = COVERAGE_FILE={envdir}/.coverage - unit,performance,integration: PYTHONTRACEMALLOC = 10 +; unit,performance,integration: PYTHONTRACEMALLOC = 5 usedevelop = true warnargs = py{37,38,39,310,311,312}: -W error -W ignore::pytest.PytestUnraisableExceptionWarning +parallel_show_output = true commands = coverage erase unit: coverage run -m pytest {[testenv]warnargs} -v {posargs} tests/unit integration: coverage run -m pytest {[testenv]warnargs} -v {posargs} tests/integration performance: python -m pytest --benchmark-autosave -v {posargs} tests/performance -; unit,integration: coverage report + unit,integration: coverage report From 7c231e259bb6943ea2104b33dce82411f4ab8db0 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 1 Nov 2024 09:43:17 +0100 Subject: [PATCH 10/18] Improve handshake logging --- src/neo4j/_async/io/_bolt_socket.py | 2 +- src/neo4j/_sync/io/_bolt_socket.py | 2 +- tox.ini | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/neo4j/_async/io/_bolt_socket.py b/src/neo4j/_async/io/_bolt_socket.py index 843aaf58b..f7964cbf0 100644 --- a/src/neo4j/_async/io/_bolt_socket.py +++ b/src/neo4j/_async/io/_bolt_socket.py @@ -242,7 +242,7 @@ async def _handshake(self, resolved_address, deadline): response = await self._handshake_read(ctx, 4) if response == b"HTTP": - log.debug("[#%04X] S: ", local_port) + log.debug("[#%04X] C: (received b'HTTP')", local_port) await self.close() raise ServiceUnavailable( f"Cannot to connect to Bolt service on {resolved_address!r} " diff --git a/src/neo4j/_sync/io/_bolt_socket.py b/src/neo4j/_sync/io/_bolt_socket.py index 223f93397..d07756f70 100644 --- a/src/neo4j/_sync/io/_bolt_socket.py +++ b/src/neo4j/_sync/io/_bolt_socket.py @@ -242,7 +242,7 @@ def _handshake(self, resolved_address, deadline): response = self._handshake_read(ctx, 4) if response == b"HTTP": - log.debug("[#%04X] S: ", local_port) + log.debug("[#%04X] C: (received b'HTTP')", local_port) self.close() raise ServiceUnavailable( f"Cannot to connect to Bolt service on {resolved_address!r} " diff --git a/tox.ini b/tox.ini index 98755d6ea..e3541a66d 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ requires = virtualenv<20.22.0 [testenv] passenv = TEST_* +deps = -r requirements-dev.txt setenv = COVERAGE_FILE={envdir}/.coverage TEST_SUITE_NAME={envname} From 60d77c80f33cccb008b90978395d1607cbd885ad Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 4 Dec 2024 17:04:53 +0100 Subject: [PATCH 11/18] Add TestKit feature flag for handshake manifest v1 --- testkitbackend/test_config.json | 1 + 1 file changed, 1 insertion(+) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index dc720619c..45592b241 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -58,6 +58,7 @@ "Feature:Bolt:5.5": true, "Feature:Bolt:5.6": true, "Feature:Bolt:5.7": true, + "Feature:Bolt:HandshakeManifestV1": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", From 80ffed082cec01325cbc4b22b75b1e9654adc464 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 28 Jan 2025 17:04:12 +0100 Subject: [PATCH 12/18] Drop support for Bolt 4.0-4.2 --- CHANGELOG.md | 1 + src/neo4j/_async/io/_bolt.py | 7 ++++++- src/neo4j/_async/io/_bolt4.py | 4 ++++ src/neo4j/_sync/io/_bolt.py | 7 ++++++- src/neo4j/_sync/io/_bolt4.py | 4 ++++ tests/unit/async_/io/test_class_bolt.py | 16 ++++++++++------ tests/unit/sync/io/test_class_bolt.py | 16 ++++++++++------ 7 files changed, 41 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38f7dc563..58746839f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. ## NEXT RELEASE - Since the types of `Relationship`s are tied to the `Graph` object they belong to, fixing `pickle` support for graph types means that `Relationship`s with the same name will have a different type after `deepcopy`ing or pickling and unpickling them or their graph. For more details, see https://github.com/neo4j/neo4j-python-driver/pull/1133 +- Drop undocumented support for Bolt protocol versions 4.0 - 4.2. ## Version 5.27 diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 35bbe039f..6cf107773 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -136,6 +136,8 @@ class AsyncBolt: # results for it. most_recent_qid = None + SKIP_REGISTRATION = False + def __init__( self, unresolved_address, @@ -262,6 +264,9 @@ def assert_notification_filtering_support(self): protocol_handlers: t.ClassVar[dict[Version, type[AsyncBolt]]] = {} def __init_subclass__(cls: type[te.Self], **kwargs: t.Any) -> None: + if cls.SKIP_REGISTRATION: + super().__init_subclass__(**kwargs) + return protocol_version = cls.PROTOCOL_VERSION if protocol_version is None: raise ValueError( @@ -295,7 +300,7 @@ def get_handshake(cls): :returns: bytes """ return ( - b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x04\x04\x04\x00\x00\x00\x03" + b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x00\x03" ) @classmethod diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index abc9d4cbe..57f41f279 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -74,6 +74,8 @@ class AsyncBolt4x0(AsyncBolt): supports_notification_filtering = False + SKIP_REGISTRATION = True + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( @@ -560,6 +562,8 @@ class AsyncBolt4x3(AsyncBolt4x2): PROTOCOL_VERSION = Version(4, 3) + SKIP_REGISTRATION = False + def get_base_headers(self): headers = super().get_base_headers() headers["patch_bolt"] = ["utc"] diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 34e6fed8a..f6bc2e06f 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -136,6 +136,8 @@ class Bolt: # results for it. most_recent_qid = None + SKIP_REGISTRATION = False + def __init__( self, unresolved_address, @@ -262,6 +264,9 @@ def assert_notification_filtering_support(self): protocol_handlers: t.ClassVar[dict[Version, type[Bolt]]] = {} def __init_subclass__(cls: type[te.Self], **kwargs: t.Any) -> None: + if cls.SKIP_REGISTRATION: + super().__init_subclass__(**kwargs) + return protocol_version = cls.PROTOCOL_VERSION if protocol_version is None: raise ValueError( @@ -295,7 +300,7 @@ def get_handshake(cls): :returns: bytes """ return ( - b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x04\x04\x04\x00\x00\x00\x03" + b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x00\x03" ) @classmethod diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 99c04185a..a802016df 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -74,6 +74,8 @@ class Bolt4x0(Bolt): supports_notification_filtering = False + SKIP_REGISTRATION = True + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( @@ -560,6 +562,8 @@ class Bolt4x3(Bolt4x2): PROTOCOL_VERSION = Version(4, 3) + SKIP_REGISTRATION = False + def get_base_headers(self): headers = super().get_base_headers() headers["patch_bolt"] = ["utc"] diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 1fa68da35..410f19c93 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -34,7 +34,7 @@ def test_class_method_protocol_handlers(): # fmt: off expected_handlers = { (3, 0), - (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), + (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -53,9 +53,9 @@ def test_class_method_protocol_handlers(): ((1, 0), 0), ((2, 0), 0), ((3, 0), 1), - ((4, 0), 1), - ((4, 1), 1), - ((4, 2), 1), + ((4, 0), 0), + ((4, 1), 0), + ((4, 2), 0), ((4, 3), 1), ((4, 4), 1), ((5, 0), 1), @@ -83,7 +83,7 @@ def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( handshake - == b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x04\x04\x04\x00\x00\x00\x03" + == b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x00\x03" ) @@ -179,8 +179,12 @@ async def test_version_negotiation( "bolt_version", ( (0, 0), + (1, 0), (2, 0), (3, 1), + (4, 0), + (4, 1), + (4, 2), (5, 9), (6, 0), ), @@ -188,7 +192,7 @@ async def test_version_negotiation( @mark_async_test async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( - "('3.0', '4.0', '4.1', '4.2', '4.3', '4.4', " + "('3.0', '4.3', '4.4', " "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" ) diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index cd3acde1e..0162128df 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -34,7 +34,7 @@ def test_class_method_protocol_handlers(): # fmt: off expected_handlers = { (3, 0), - (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), + (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -53,9 +53,9 @@ def test_class_method_protocol_handlers(): ((1, 0), 0), ((2, 0), 0), ((3, 0), 1), - ((4, 0), 1), - ((4, 1), 1), - ((4, 2), 1), + ((4, 0), 0), + ((4, 1), 0), + ((4, 2), 0), ((4, 3), 1), ((4, 4), 1), ((5, 0), 1), @@ -83,7 +83,7 @@ def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( handshake - == b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x04\x04\x04\x00\x00\x00\x03" + == b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x00\x03" ) @@ -179,8 +179,12 @@ def test_version_negotiation( "bolt_version", ( (0, 0), + (1, 0), (2, 0), (3, 1), + (4, 0), + (4, 1), + (4, 2), (5, 9), (6, 0), ), @@ -188,7 +192,7 @@ def test_version_negotiation( @mark_sync_test def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( - "('3.0', '4.0', '4.1', '4.2', '4.3', '4.4', " + "('3.0', '4.3', '4.4', " "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" ) From 68e8ce65459af727482dd96d6268b73ff01c8cad Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 28 Jan 2025 18:45:55 +0100 Subject: [PATCH 13/18] Tests: fix event loop shutdown for Python <= 3.8 --- tests/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 348546244..09758069d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -211,11 +211,12 @@ def watcher(): def event_loop(): policy = asyncio.get_event_loop_policy() loop = policy.new_event_loop() + yield loop try: - yield loop _cancel_all_tasks(loop) loop.run_until_complete(loop.shutdown_asyncgens()) - loop.run_until_complete(loop.shutdown_default_executor()) + if sys.version_info >= (3, 9): + loop.run_until_complete(loop.shutdown_default_executor()) finally: loop.close() From 89c70048e2299ba50fcdabd319c133dc40930ccd Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 28 Jan 2025 18:52:51 +0100 Subject: [PATCH 14/18] fixup! Drop support for Bolt 4.0-4.2 --- src/neo4j/_async/io/_bolt4.py | 4 ++-- src/neo4j/_sync/io/_bolt4.py | 4 ++-- testkitbackend/test_config.json | 2 ++ tests/unit/async_/io/test_class_bolt.py | 7 +++---- tests/unit/sync/io/test_class_bolt.py | 7 +++---- tox.ini | 2 +- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 57f41f279..138ee51bb 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -552,6 +552,8 @@ class AsyncBolt4x2(AsyncBolt4x1): PROTOCOL_VERSION = Version(4, 2) + SKIP_REGISTRATION = False + class AsyncBolt4x3(AsyncBolt4x2): """ @@ -562,8 +564,6 @@ class AsyncBolt4x3(AsyncBolt4x2): PROTOCOL_VERSION = Version(4, 3) - SKIP_REGISTRATION = False - def get_base_headers(self): headers = super().get_base_headers() headers["patch_bolt"] = ["utc"] diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index a802016df..19c719240 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -552,6 +552,8 @@ class Bolt4x2(Bolt4x1): PROTOCOL_VERSION = Version(4, 2) + SKIP_REGISTRATION = False + class Bolt4x3(Bolt4x2): """ @@ -562,8 +564,6 @@ class Bolt4x3(Bolt4x2): PROTOCOL_VERSION = Version(4, 3) - SKIP_REGISTRATION = False - def get_base_headers(self): headers = super().get_base_headers() headers["patch_bolt"] = ["utc"] diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index ef831d308..be37f132e 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -50,6 +50,8 @@ "Feature:Auth:Kerberos": true, "Feature:Auth:Managed": true, "Feature:Bolt:3.0": true, + "Feature:Bolt:4.0": "Dropped support legacy protocol version 4.0", + "Feature:Bolt:4.1": "Dropped support legacy protocol version 4.1", "Feature:Bolt:4.2": true, "Feature:Bolt:4.3": true, "Feature:Bolt:4.4": true, diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 410f19c93..e8b9e8886 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -34,7 +34,7 @@ def test_class_method_protocol_handlers(): # fmt: off expected_handlers = { (3, 0), - (4, 3), (4, 4), + (4, 2), (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -55,7 +55,7 @@ def test_class_method_protocol_handlers(): ((3, 0), 1), ((4, 0), 0), ((4, 1), 0), - ((4, 2), 0), + ((4, 2), 1), ((4, 3), 1), ((4, 4), 1), ((5, 0), 1), @@ -184,7 +184,6 @@ async def test_version_negotiation( (3, 1), (4, 0), (4, 1), - (4, 2), (5, 9), (6, 0), ), @@ -192,7 +191,7 @@ async def test_version_negotiation( @mark_async_test async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( - "('3.0', '4.3', '4.4', " + "('3.0', '4.2', '4.3', '4.4', " "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" ) diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index 0162128df..25a22ecb4 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -34,7 +34,7 @@ def test_class_method_protocol_handlers(): # fmt: off expected_handlers = { (3, 0), - (4, 3), (4, 4), + (4, 2), (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -55,7 +55,7 @@ def test_class_method_protocol_handlers(): ((3, 0), 1), ((4, 0), 0), ((4, 1), 0), - ((4, 2), 0), + ((4, 2), 1), ((4, 3), 1), ((4, 4), 1), ((5, 0), 1), @@ -184,7 +184,6 @@ def test_version_negotiation( (3, 1), (4, 0), (4, 1), - (4, 2), (5, 9), (6, 0), ), @@ -192,7 +191,7 @@ def test_version_negotiation( @mark_sync_test def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( - "('3.0', '4.3', '4.4', " + "('3.0', '4.2', '4.3', '4.4', " "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" ) diff --git a/tox.ini b/tox.ini index e3541a66d..c6f46f7c8 100644 --- a/tox.ini +++ b/tox.ini @@ -11,7 +11,7 @@ setenv = TEST_SUITE_NAME={envname} usedevelop = true warnargs = - py{37,38,39,310,311,312}: -W error -W ignore::pytest.PytestUnraisableExceptionWarning + -W error -W ignore::pytest.PytestUnraisableExceptionWarning parallel_show_output = true commands = coverage erase From c802344c1e5a1abe45c444ad8ac4303ae0f89c28 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 28 Jan 2025 19:50:31 +0100 Subject: [PATCH 15/18] amend! Drop support for Bolt 4.0-4.2 Drop support for Bolt 4.1 --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 58746839f..102b21301 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. ## NEXT RELEASE - Since the types of `Relationship`s are tied to the `Graph` object they belong to, fixing `pickle` support for graph types means that `Relationship`s with the same name will have a different type after `deepcopy`ing or pickling and unpickling them or their graph. For more details, see https://github.com/neo4j/neo4j-python-driver/pull/1133 -- Drop undocumented support for Bolt protocol versions 4.0 - 4.2. +- Drop undocumented support for Bolt protocol versions 4.1. ## Version 5.27 From 923e8f58554f07d93708a21e97e4893f9dbacd27 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 29 Jan 2025 09:25:01 +0100 Subject: [PATCH 16/18] Revert: Run against TestKit's 5.0 branch again --- testkit/testkit.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testkit/testkit.json b/testkit/testkit.json index 9e496140c..931900356 100644 --- a/testkit/testkit.json +++ b/testkit/testkit.json @@ -1,6 +1,6 @@ { "testkit": { "uri": "https://github.com/neo4j-drivers/testkit.git", - "ref": "bolt-handshake-v2" + "ref": "5.0" } } From 0db7ca31d3f76709450f8f4efc206ac26abb802e Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 4 Feb 2025 09:33:07 +0100 Subject: [PATCH 17/18] TestKit glue: remove tracemalloc option for faster testing --- testkit/_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testkit/_common.py b/testkit/_common.py index 1c39b36de..7e3e6a346 100644 --- a/testkit/_common.py +++ b/testkit/_common.py @@ -47,6 +47,6 @@ def get_python_version(): def run_python(args, env=None, warning_as_error=True): cmd = [TEST_BACKEND_VERSION, "-u"] if warning_as_error: - cmd += ["-W", "error", "-X", "tracemalloc=10"] + cmd += ["-W", "error"] cmd += list(args) run(cmd, env=env) From 20f488ff25c07caa72eb81577616abf5cf05fa27 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 4 Feb 2025 09:57:29 +0100 Subject: [PATCH 18/18] Update comments --- src/neo4j/_async/io/_bolt.py | 1 - src/neo4j/_sync/io/_bolt.py | 1 - tests/conftest.py | 10 ++++++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 6cf107773..b7a86ebfd 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -336,7 +336,6 @@ async def ping(cls, address, *, deadline=None, pool_config=None): await AsyncBoltSocket.close_socket(s) return protocol_version - # [bolt-version-bump] search tag when changing bolt version support @classmethod async def open( cls, diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index f6bc2e06f..39b72d9c8 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -336,7 +336,6 @@ def ping(cls, address, *, deadline=None, pool_config=None): BoltSocket.close_socket(s) return protocol_version - # [bolt-version-bump] search tag when changing bolt version support @classmethod def open( cls, diff --git a/tests/conftest.py b/tests/conftest.py index 09758069d..b6cdd7f7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -209,6 +209,9 @@ def watcher(): # check if this fixture is still needed @pytest.fixture def event_loop(): + # Overwriting the default event loop injected by pytest-asyncio + # because its implementation doesn't properly shut down the loop + # (e.g., it doesn't call `shutdown_asyncgens`) policy = asyncio.get_event_loop_policy() loop = policy.new_event_loop() yield loop @@ -222,6 +225,13 @@ def event_loop(): def _cancel_all_tasks(loop): + # Copied from Python 3.13's asyncio package with minor modifications + # in exception wording and variable naming + + # Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, + # 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022 + # Python Software Foundation; + # All Rights Reserved to_cancel = asyncio.all_tasks(loop) if not to_cancel: return