diff --git a/CHANGELOG.md b/CHANGELOG.md index 38f7dc563..102b21301 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. ## NEXT RELEASE - Since the types of `Relationship`s are tied to the `Graph` object they belong to, fixing `pickle` support for graph types means that `Relationship`s with the same name will have a different type after `deepcopy`ing or pickling and unpickling them or their graph. For more details, see https://github.com/neo4j/neo4j-python-driver/pull/1133 +- Drop undocumented support for Bolt protocol versions 4.1. ## Version 5.27 diff --git a/pyproject.toml b/pyproject.toml index 06affa14f..2cb079452 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,7 @@ use_parentheses = true [tool.pytest.ini_options] mock_use_standalone_module = true -asyncio_mode = "auto" +asyncio_mode = "strict" [tool.mypy] diff --git a/src/neo4j/_async/io/__init__.py b/src/neo4j/_async/io/__init__.py index 7f068da77..fddadd1d1 100644 --- a/src/neo4j/_async/io/__init__.py +++ b/src/neo4j/_async/io/__init__.py @@ -32,6 +32,12 @@ ] +# [bolt-version-bump] search tag when changing bolt version support +from . import ( # noqa - imports needed to register protocol handlers + _bolt3, + _bolt4, + _bolt5, +) from ._bolt import AsyncBolt from ._common import ( check_supported_server_product, diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index b2a73977d..b7a86ebfd 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -23,7 +23,6 @@ from logging import getLogger from time import monotonic -from ..._async_compat.network import AsyncBoltSocket from ..._async_compat.util import AsyncUtil from ..._auth_management import to_auth_dict from ..._codec.hydration import ( @@ -51,6 +50,7 @@ SessionExpired, ) from ..config import AsyncPoolConfig +from ._bolt_socket import AsyncBoltSocket from ._common import ( AsyncInbox, AsyncOutbox, @@ -59,6 +59,8 @@ if t.TYPE_CHECKING: + import typing_extensions as te + from ..._api import TelemetryAPI @@ -134,6 +136,8 @@ class AsyncBolt: # results for it. most_recent_qid = None + SKIP_REGISTRATION = False + def __init__( self, unresolved_address, @@ -257,101 +261,36 @@ def assert_notification_filtering_support(self): f"{self.server_info.agent!r}" ) - # [bolt-version-bump] search tag when changing bolt version support - @classmethod - def protocol_handlers(cls, protocol_version=None): - """ - Return a dictionary of available Bolt protocol handlers. - - The handlers are keyed by version tuple. If an explicit protocol - version is provided, the dictionary will contain either zero or one - items, depending on whether that version is supported. If no protocol - version is provided, all available versions will be returned. - - :param protocol_version: tuple identifying a specific protocol - version (e.g. (3, 5)) or None - :returns: dictionary of version tuple to handler class for all - relevant and supported protocol versions - :raise TypeError: if protocol version is not passed in a tuple - """ - # Carry out Bolt subclass imports locally to avoid circular dependency - # issues. - from ._bolt3 import AsyncBolt3 - from ._bolt4 import ( - AsyncBolt4x1, - AsyncBolt4x2, - AsyncBolt4x3, - AsyncBolt4x4, - ) - from ._bolt5 import ( - AsyncBolt5x0, - AsyncBolt5x1, - AsyncBolt5x2, - AsyncBolt5x3, - AsyncBolt5x4, - AsyncBolt5x5, - AsyncBolt5x6, - AsyncBolt5x7, - AsyncBolt5x8, - ) - - handlers = { - AsyncBolt3.PROTOCOL_VERSION: AsyncBolt3, - # 4.0 unsupported because no space left in the handshake - AsyncBolt4x1.PROTOCOL_VERSION: AsyncBolt4x1, - AsyncBolt4x2.PROTOCOL_VERSION: AsyncBolt4x2, - AsyncBolt4x3.PROTOCOL_VERSION: AsyncBolt4x3, - AsyncBolt4x4.PROTOCOL_VERSION: AsyncBolt4x4, - AsyncBolt5x0.PROTOCOL_VERSION: AsyncBolt5x0, - AsyncBolt5x1.PROTOCOL_VERSION: AsyncBolt5x1, - AsyncBolt5x2.PROTOCOL_VERSION: AsyncBolt5x2, - AsyncBolt5x3.PROTOCOL_VERSION: AsyncBolt5x3, - AsyncBolt5x4.PROTOCOL_VERSION: AsyncBolt5x4, - AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5, - AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, - AsyncBolt5x7.PROTOCOL_VERSION: AsyncBolt5x7, - AsyncBolt5x8.PROTOCOL_VERSION: AsyncBolt5x8, - } + protocol_handlers: t.ClassVar[dict[Version, type[AsyncBolt]]] = {} + def __init_subclass__(cls: type[te.Self], **kwargs: t.Any) -> None: + if cls.SKIP_REGISTRATION: + super().__init_subclass__(**kwargs) + return + protocol_version = cls.PROTOCOL_VERSION if protocol_version is None: - return handlers - - if not isinstance(protocol_version, tuple): - raise TypeError("Protocol version must be specified as a tuple") - - if protocol_version in handlers: - return {protocol_version: handlers[protocol_version]} - - return {} - - @classmethod - def version_list(cls, versions): - """ - Return a list of supported protocol versions in order of preference. - - The number of protocol versions (or ranges) returned is limited to 4. - """ - # In fact, 4.3 is the fist version to support ranges. However, the - # range support got backported to 4.2. But even if the server is too - # old to have the backport, negotiating BOLT 4.1 is no problem as it's - # equivalent to 4.2 - first_with_range_support = Version(4, 2) - result = [] - for version in versions: - if ( - result - and version >= first_with_range_support - and result[-1][0] == version[0] - and result[-1][1][1] == version[1] + 1 - ): - # can use range to encompass this version - result[-1][1][1] = version[1] - continue - result.append(Version(version[0], [version[1], version[1]])) - if len(result) == 4: - break - return result + raise ValueError( + "AsyncBolt subclasses must define PROTOCOL_VERSION" + ) + if not ( + isinstance(protocol_version, Version) + and len(protocol_version) == 2 + and all(isinstance(i, int) for i in protocol_version) + ): + raise TypeError( + "PROTOCOL_VERSION must be a 2-tuple of integers, not " + f"{protocol_version!r}" + ) + if protocol_version in AsyncBolt.protocol_handlers: + cls_conflict = AsyncBolt.protocol_handlers[protocol_version] + raise TypeError( + f"Multiple classes for the same protocol version " + f"{protocol_version}: {cls}, {cls_conflict}" + ) + cls.protocol_handlers[protocol_version] = cls + super().__init_subclass__(**kwargs) + # [bolt-version-bump] search tag when changing bolt version support @classmethod def get_handshake(cls): """ @@ -360,12 +299,9 @@ def get_handshake(cls): The length is 16 bytes as specified in the Bolt version negotiation. :returns: bytes """ - supported_versions = sorted( - cls.protocol_handlers().keys(), reverse=True + return ( + b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x00\x03" ) - offered_versions = cls.version_list(supported_versions) - versions_bytes = (v.to_bytes() for v in offered_versions) - return b"".join(versions_bytes).ljust(16, b"\x00") @classmethod async def ping(cls, address, *, deadline=None, pool_config=None): @@ -400,7 +336,6 @@ async def ping(cls, address, *, deadline=None, pool_config=None): await AsyncBoltSocket.close_socket(s) return protocol_version - # [bolt-version-bump] search tag when changing bolt version support @classmethod async def open( cls, @@ -441,71 +376,17 @@ async def open( ) pool_config.protocol_version = protocol_version - - # Carry out Bolt subclass imports locally to avoid circular dependency - # issues. - - # avoid new lines after imports for better readability and conciseness - # fmt: off - if protocol_version == (5, 8): - from ._bolt5 import AsyncBolt5x8 - bolt_cls = AsyncBolt5x8 - elif protocol_version == (5, 7): - from ._bolt5 import AsyncBolt5x7 - bolt_cls = AsyncBolt5x7 - elif protocol_version == (5, 6): - from ._bolt5 import AsyncBolt5x6 - bolt_cls = AsyncBolt5x6 - elif protocol_version == (5, 5): - from ._bolt5 import AsyncBolt5x5 - bolt_cls = AsyncBolt5x5 - elif protocol_version == (5, 4): - from ._bolt5 import AsyncBolt5x4 - bolt_cls = AsyncBolt5x4 - elif protocol_version == (5, 3): - from ._bolt5 import AsyncBolt5x3 - bolt_cls = AsyncBolt5x3 - elif protocol_version == (5, 2): - from ._bolt5 import AsyncBolt5x2 - bolt_cls = AsyncBolt5x2 - elif protocol_version == (5, 1): - from ._bolt5 import AsyncBolt5x1 - bolt_cls = AsyncBolt5x1 - elif protocol_version == (5, 0): - from ._bolt5 import AsyncBolt5x0 - bolt_cls = AsyncBolt5x0 - elif protocol_version == (4, 4): - from ._bolt4 import AsyncBolt4x4 - bolt_cls = AsyncBolt4x4 - elif protocol_version == (4, 3): - from ._bolt4 import AsyncBolt4x3 - bolt_cls = AsyncBolt4x3 - elif protocol_version == (4, 2): - from ._bolt4 import AsyncBolt4x2 - bolt_cls = AsyncBolt4x2 - elif protocol_version == (4, 1): - from ._bolt4 import AsyncBolt4x1 - bolt_cls = AsyncBolt4x1 - # Implementation for 4.0 exists, but there was no space left in the - # handshake to offer this version to the server. Hence, the server - # should never request us to speak bolt 4.0. - # elif protocol_version == (4, 0): - # from ._bolt4 import AsyncBolt4x0 - # bolt_cls = AsyncBolt4x0 - elif protocol_version == (3, 0): - from ._bolt3 import AsyncBolt3 - bolt_cls = AsyncBolt3 - # fmt: on - else: + protocol_handlers = AsyncBolt.protocol_handlers + bolt_cls = protocol_handlers.get(protocol_version) + if bolt_cls is None: log.debug("[#%04X] C: ", s.getsockname()[1]) await AsyncBoltSocket.close_socket(s) - supported_versions = cls.protocol_handlers().keys() # TODO: 6.0 - raise public DriverError subclass instead raise BoltHandshakeError( "The neo4j server does not support communication with this " "driver. This driver has support for Bolt protocols " - f"{tuple(map(str, supported_versions))}.", + f"{tuple(map(str, AsyncBolt.protocol_handlers.keys()))}.", address=address, request_data=handshake, response_data=data, diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index abc9d4cbe..138ee51bb 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -74,6 +74,8 @@ class AsyncBolt4x0(AsyncBolt): supports_notification_filtering = False + SKIP_REGISTRATION = True + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( @@ -550,6 +552,8 @@ class AsyncBolt4x2(AsyncBolt4x1): PROTOCOL_VERSION = Version(4, 2) + SKIP_REGISTRATION = False + class AsyncBolt4x3(AsyncBolt4x2): """ diff --git a/src/neo4j/_async/io/_bolt_socket.py b/src/neo4j/_async/io/_bolt_socket.py new file mode 100644 index 000000000..f7964cbf0 --- /dev/null +++ b/src/neo4j/_async/io/_bolt_socket.py @@ -0,0 +1,359 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import asyncio +import dataclasses +import logging +import struct +import typing as t +from contextlib import suppress + +from ... import addressing +from ..._async_compat.network import ( + AsyncBoltSocketBase, + AsyncNetworkUtil, +) +from ..._exceptions import ( + BoltError, + BoltProtocolError, +) +from ...exceptions import ( + DriverError, + ServiceUnavailable, +) + + +if t.TYPE_CHECKING: + from ..._deadline import Deadline + + +log = logging.getLogger("neo4j.io") + + +@dataclasses.dataclass +class HandshakeCtx: + ctx: str + deadline: Deadline + local_port: int + resolved_address: addressing.ResolvedAddress + full_response: bytearray = dataclasses.field(default_factory=bytearray) + + +@dataclasses.dataclass +class BytesPrinter: + bytes: bytes | bytearray + + def __str__(self): + return f"0x{self.bytes.hex().upper()}" + + +class AsyncBoltSocket(AsyncBoltSocketBase): + async def _parse_handshake_response_v1(self, ctx, response): + agreed_version = response[-1], response[-2] + log.debug( + "[#%04X] S: 0x%06X%02X", + ctx.local_port, + agreed_version[1], + agreed_version[0], + ) + return agreed_version + + async def _parse_handshake_response_v2(self, ctx, response): + ctx.ctx = "handshake v2 offerings count" + num_offerings = await self._read_varint(ctx) + offerings = [] + for i in range(num_offerings): + ctx.ctx = f"handshake v2 offering {i}" + offering_response = await self._handshake_read(ctx, 4) + offering = offering_response[-1:-4:-1] + offerings.append(offering) + ctx.ctx = "handshake v2 capabilities" + _capabilities_offer = await self._read_varint(ctx) + + if log.getEffectiveLevel() >= logging.DEBUG: + log.debug( + "[#%04X] S: %s [%i] %s %s", + ctx.local_port, + BytesPrinter(response), + num_offerings, + " ".join( + f"0x{vx[2]:04X}{vx[1]:02X}{vx[0]:02X}" for vx in offerings + ), + BytesPrinter(self._encode_varint(_capabilities_offer)), + ) + + supported_versions = sorted(self.Bolt.protocol_handlers.keys()) + chosen_version = 0, 0 + for v in supported_versions: + for offer_major, offer_minor, offer_range in offerings: + offer_max = (offer_major, offer_minor) + offer_min = (offer_major, offer_minor - offer_range) + if offer_min <= v <= offer_max: + chosen_version = v + break + + ctx.ctx = "handshake v2 chosen version" + await self._handshake_send( + ctx, bytes((0, 0, chosen_version[1], chosen_version[0])) + ) + chosen_capabilities = 0 + capabilities = self._encode_varint(chosen_capabilities) + ctx.ctx = "handshake v2 chosen capabilities" + log.debug( + "[#%04X] C: 0x%06X%02X %s", + ctx.local_port, + chosen_version[1], + chosen_version[0], + BytesPrinter(capabilities), + ) + await self._handshake_send(ctx, b"\x00") + + return chosen_version + + async def _read_varint(self, ctx): + next_byte = (await self._handshake_read(ctx, 1))[0] + res = next_byte & 0x7F + i = 0 + while next_byte & 0x80: + i += 1 + next_byte = (await self._handshake_read(ctx, 1))[0] + res += (next_byte & 0x7F) << (7 * i) + return res + + @staticmethod + def _encode_varint(n): + res = bytearray() + while n >= 0x80: + res.append(n & 0x7F | 0x80) + n >>= 7 + res.append(n) + return res + + async def _handshake_read(self, ctx, n): + original_timeout = self.gettimeout() + self.settimeout(ctx.deadline.to_timeout()) + try: + response = await self.recv(n) + ctx.full_response.extend(response) + except OSError as exc: + raise ServiceUnavailable( + f"Failed to read {ctx.ctx} from server " + f"{ctx.resolved_address!r} (deadline {ctx.deadline})" + ) from exc + finally: + self.settimeout(original_timeout) + data_size = len(response) + if data_size == 0: + # If no data is returned after a successful select + # response, the server has closed the connection + log.debug("[#%04X] S: ", ctx.local_port) + await self.close() + raise ServiceUnavailable( + f"Connection to {ctx.resolved_address} closed with incomplete " + f"handshake response" + ) + if data_size != n: + # Some garbled data has been received + log.debug("[#%04X] S: @*#!", ctx.local_port) + await self.close() + raise BoltProtocolError( + f"Expected {ctx.ctx} from {ctx.resolved_address!r}, received " + f"{response!r} instead (so far {ctx.full_response!r}); " + "check for incorrect port number", + address=ctx.resolved_address, + ) + + return response + + async def _handshake_send(self, ctx, data): + original_timeout = self.gettimeout() + self.settimeout(ctx.deadline.to_timeout()) + try: + await self.sendall(data) + except OSError as exc: + raise ServiceUnavailable( + f"Failed to write {ctx.ctx} to server " + f"{ctx.resolved_address!r} (deadline {ctx.deadline})" + ) from exc + finally: + self.settimeout(original_timeout) + + async def _handshake(self, resolved_address, deadline): + """ + Perform BOLT handshake. + + :param resolved_address: + :param deadline: Deadline for handshake + + :returns: (version, client_handshake, server_response_data) + """ + local_port = self.getsockname()[1] + + if log.getEffectiveLevel() >= logging.DEBUG: + handshake = self.Bolt.get_handshake() + handshake = struct.unpack(">16B", handshake) + handshake = [ + handshake[i : i + 4] for i in range(0, len(handshake), 4) + ] + + supported_versions = [ + f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" + for vx in handshake + ] + + log.debug( + "[#%04X] C: 0x%08X", + local_port, + int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"), + ) + log.debug( + "[#%04X] C: %s %s %s %s", + local_port, + *supported_versions, + ) + + request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() + + ctx = HandshakeCtx( + ctx="handshake opening", + deadline=deadline, + local_port=local_port, + resolved_address=resolved_address, + ) + + await self._handshake_send(ctx, request) + + ctx.ctx = "four byte Bolt handshake response" + response = await self._handshake_read(ctx, 4) + + if response == b"HTTP": + log.debug("[#%04X] C: (received b'HTTP')", local_port) + await self.close() + raise ServiceUnavailable( + f"Cannot to connect to Bolt service on {resolved_address!r} " + "(looks like HTTP)" + ) + elif response[-1] == 0xFF: + # manifest style handshake + manifest_version = response[-2] + if manifest_version == 0x01: + agreed_version = await self._parse_handshake_response_v2( + ctx, + response, + ) + else: + raise BoltProtocolError( + "Unsupported Bolt handshake manifest version " + f"{manifest_version} received from {resolved_address!r}.", + address=resolved_address, + ) + else: + agreed_version = await self._parse_handshake_response_v1( + ctx, + response, + ) + + return agreed_version, handshake, response + + @classmethod + async def connect( + cls, + address, + *, + tcp_timeout, + deadline, + custom_resolver, + ssl_context, + keep_alive, + ): + """ + Connect and perform a handshake. + + Return a valid Connection object, assuming a protocol version can be + agreed. + """ + errors = [] + failed_addresses = [] + # Establish a connection to the host and port specified + # Catches refused connections see: + # https://docs.python.org/2/library/errno.html + + resolved_addresses = AsyncNetworkUtil.resolve_address( + addressing.Address(address), resolver=custom_resolver + ) + async for resolved_address in resolved_addresses: + deadline_timeout = deadline.to_timeout() + if ( + deadline_timeout is not None + and deadline_timeout <= tcp_timeout + ): + tcp_timeout = deadline_timeout + s = None + try: + s = await cls._connect_secure( + resolved_address, tcp_timeout, keep_alive, ssl_context + ) + agreed_version, handshake, response = await s._handshake( + resolved_address, deadline + ) + return s, agreed_version, handshake, response + except (BoltError, DriverError, OSError) as error: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError, TypeError): + local_port = 0 + err_str = error.__class__.__name__ + if str(error): + err_str += ": " + str(error) + log.debug( + "[#%04X] S: %s %s", + local_port, + resolved_address, + err_str, + ) + if s: + await cls.close_socket(s) + errors.append(error) + failed_addresses.append(resolved_address) + except asyncio.CancelledError: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError, TypeError): + local_port = 0 + log.debug( + "[#%04X] C: %s", local_port, resolved_address + ) + if s: + with suppress(OSError): + s.kill() + raise + except Exception: + if s: + await cls.close_socket(s) + raise + address_strs = tuple(map(str, failed_addresses)) + if not errors: + raise ServiceUnavailable( + f"Couldn't connect to {address} (resolved to {address_strs})" + ) + else: + error_strs = "\n".join(map(str, errors)) + raise ServiceUnavailable( + f"Couldn't connect to {address} (resolved to {address_strs}):" + f"\n{error_strs}" + ) from errors[0] diff --git a/src/neo4j/_async_compat/network/__init__.py b/src/neo4j/_async_compat/network/__init__.py index e5777595f..4da0f32e6 100644 --- a/src/neo4j/_async_compat/network/__init__.py +++ b/src/neo4j/_async_compat/network/__init__.py @@ -15,8 +15,8 @@ from ._bolt_socket import ( - AsyncBoltSocket, - BoltSocket, + AsyncBoltSocketBase, + BoltSocketBase, ) from ._util import ( AsyncNetworkUtil, @@ -25,8 +25,8 @@ __all__ = [ - "AsyncBoltSocket", + "AsyncBoltSocketBase", "AsyncNetworkUtil", - "BoltSocket", + "BoltSocketBase", "NetworkUtil", ] diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py index 357e7410f..cccd16437 100644 --- a/src/neo4j/_async_compat/network/_bolt_socket.py +++ b/src/neo4j/_async_compat/network/_bolt_socket.py @@ -16,10 +16,10 @@ from __future__ import annotations +import abc import asyncio import errno import logging -import struct import typing as t from contextlib import suppress @@ -47,23 +47,14 @@ SSLSocket, ) -from ... import addressing from ..._deadline import Deadline from ..._exceptions import ( - BoltError, BoltProtocolError, BoltSecurityError, SocketDeadlineExceededError, ) -from ...exceptions import ( - DriverError, - ServiceUnavailable, -) +from ...exceptions import ServiceUnavailable from ..shims import wait_for -from ._util import ( - AsyncNetworkUtil, - NetworkUtil, -) if t.TYPE_CHECKING: @@ -85,7 +76,7 @@ def _sanitize_deadline(deadline): return deadline -class AsyncBoltSocket: +class AsyncBoltSocketBase(abc.ABC): Bolt: te.Final[type[AsyncBolt]] = None # type: ignore[assignment] def __init__(self, reader, protocol, writer): @@ -178,19 +169,22 @@ def kill(self): self._writer.close() @classmethod - async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): + async def _connect_secure( + cls, resolved_address, timeout, keep_alive, ssl_context + ): """ Connect to the address and return the socket. :param resolved_address: :param timeout: seconds :param keep_alive: True or False - :param ssl: SSLContext or None + :param ssl_context: SSLContext or None :returns: AsyncBoltSocket object """ loop = asyncio.get_event_loop() s = None + local_port = 0 # TODO: tomorrow me: fix this mess try: @@ -203,17 +197,18 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): s.setblocking(False) # asyncio + blocking = no-no! log.debug("[#0000] C: %s", resolved_address) await wait_for(loop.sock_connect(s, resolved_address), timeout) + local_port = s.getsockname()[1] keep_alive = 1 if keep_alive else 0 s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) ssl_kwargs = {} - if ssl is not None: + if ssl_context is not None: hostname = resolved_address._host_name or None - ssl_kwargs.update( - ssl=ssl, server_hostname=hostname if HAS_SNI else None - ) + sni_host = hostname if HAS_SNI and hostname else None + ssl_kwargs.update(ssl=ssl_context, server_hostname=sni_host) + log.debug("[#%04X] C: %s", local_port, hostname) reader = asyncio.StreamReader( limit=2**16, # 64 KiB, @@ -225,7 +220,7 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): ) writer = asyncio.StreamWriter(transport, protocol, reader, loop) - if ssl is not None: + if ssl_context is not None: # Check that the server provides a certificate der_encoded_server_certificate = transport.get_extra_info( "ssl_object" @@ -256,7 +251,6 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): cls._kill_raw_socket(s) raise except (SSLError, CertificateError) as error: - local_port = s.getsockname()[1] if s: cls._kill_raw_socket(s) raise BoltSecurityError( @@ -279,92 +273,25 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): ) from error raise - async def _handshake(self, resolved_address, deadline): - """ - Perform BOLT handshake. - - :param resolved_address: - :param deadline: Deadline for handshake + @abc.abstractmethod + async def _handshake(self, resolved_address, deadline): ... - :returns: (socket, version, client_handshake, server_response_data) - """ - local_port = self.getsockname()[1] - - # TODO: Optimize logging code - handshake = self.Bolt.get_handshake() - handshake = struct.unpack(">16B", handshake) - handshake = [handshake[i : i + 4] for i in range(0, len(handshake), 4)] - - supported_versions = [ - f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" - for vx in handshake - ] - - log.debug( - "[#%04X] C: 0x%08X", - local_port, - int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"), - ) - log.debug( - "[#%04X] C: %s %s %s %s", - local_port, - *supported_versions, - ) - - request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() - - # Handle the handshake response - original_timeout = self.gettimeout() - self.settimeout(deadline.to_timeout()) - try: - await self.sendall(request) - response = await self.recv(4) - except OSError as exc: - raise ServiceUnavailable( - f"Failed to read any data from server {resolved_address!r} " - f"after connected (deadline {deadline})" - ) from exc - finally: - self.settimeout(original_timeout) - data_size = len(response) - if data_size == 0: - # If no data is returned after a successful select - # response, the server has closed the connection - log.debug("[#%04X] S: ", local_port) - await self.close() - raise ServiceUnavailable( - f"Connection to {resolved_address} closed without handshake " - "response" - ) - if data_size != 4: - # Some garbled data has been received - log.debug("[#%04X] S: @*#!", local_port) - await self.close() - raise BoltProtocolError( - "Expected four byte Bolt handshake response from " - f"{resolved_address!r}, received {response!r} instead; " - "check for incorrect port number", - address=resolved_address, - ) - elif response == b"HTTP": - log.debug("[#%04X] S: ", local_port) - await self.close() - raise ServiceUnavailable( - f"Cannot to connect to Bolt service on {resolved_address!r} " - "(looks like HTTP)" - ) - agreed_version = response[-1], response[-2] - log.debug( - "[#%04X] S: 0x%06X%02X", - local_port, - agreed_version[1], - agreed_version[0], - ) - return self, agreed_version, handshake, response + @classmethod + @abc.abstractmethod + async def connect( + cls, + address, + *, + tcp_timeout, + deadline, + custom_resolver, + ssl_context, + keep_alive, + ): ... @classmethod async def close_socket(cls, socket_): - if isinstance(socket_, AsyncBoltSocket): + if isinstance(socket_, AsyncBoltSocketBase): with suppress(OSError): await socket_.close() else: @@ -377,93 +304,8 @@ def _kill_raw_socket(cls, socket_): with suppress(OSError): socket_.close() - @classmethod - async def connect( - cls, - address, - *, - tcp_timeout, - deadline, - custom_resolver, - ssl_context, - keep_alive, - ): - """ - Connect and perform a handshake. - Return a valid Connection object, assuming a protocol version can be - agreed. - """ - errors = [] - failed_addresses = [] - # Establish a connection to the host and port specified - # Catches refused connections see: - # https://docs.python.org/2/library/errno.html - - resolved_addresses = AsyncNetworkUtil.resolve_address( - addressing.Address(address), resolver=custom_resolver - ) - async for resolved_address in resolved_addresses: - deadline_timeout = deadline.to_timeout() - if ( - deadline_timeout is not None - and deadline_timeout <= tcp_timeout - ): - tcp_timeout = deadline_timeout - s = None - try: - s = await cls._connect_secure( - resolved_address, tcp_timeout, keep_alive, ssl_context - ) - return await s._handshake(resolved_address, deadline) - except (BoltError, DriverError, OSError) as error: - try: - local_port = s.getsockname()[1] - except (OSError, AttributeError, TypeError): - local_port = 0 - err_str = error.__class__.__name__ - if str(error): - err_str += ": " + str(error) - log.debug( - "[#%04X] C: %s %s", - local_port, - resolved_address, - err_str, - ) - if s: - await cls.close_socket(s) - errors.append(error) - failed_addresses.append(resolved_address) - except asyncio.CancelledError: - try: - local_port = s.getsockname()[1] - except (OSError, AttributeError, TypeError): - local_port = 0 - log.debug( - "[#%04X] C: %s", local_port, resolved_address - ) - if s: - with suppress(OSError): - s.kill() - raise - except Exception: - if s: - await cls.close_socket(s) - raise - address_strs = tuple(map(str, failed_addresses)) - if not errors: - raise ServiceUnavailable( - f"Couldn't connect to {address} (resolved to {address_strs})" - ) - else: - error_strs = "\n".join(map(str, errors)) - raise ServiceUnavailable( - f"Couldn't connect to {address} (resolved to {address_strs}):" - f"\n{error_strs}" - ) from errors[0] - - -class BoltSocket: +class BoltSocketBase: Bolt: te.Final[type[Bolt]] = None # type: ignore[assignment] def __init__(self, socket_: socket): @@ -531,7 +373,9 @@ def kill(self): self._socket.close() @classmethod - def _connect(cls, resolved_address, timeout, keep_alive): + def _connect_secure( + cls, resolved_address, timeout, keep_alive, ssl_context + ): """ Connect to the address and return the socket. @@ -543,175 +387,84 @@ def _connect(cls, resolved_address, timeout, keep_alive): s = None # The socket try: - if len(resolved_address) == 2: - s = socket(AF_INET) - elif len(resolved_address) == 4: - s = socket(AF_INET6) - else: - raise ValueError(f"Unsupported address {resolved_address!r}") try: - s.setsockopt(IPPROTO_TCP, TCP_NODELAY, 1) - except OSError as e: - # option might not be supported on all platforms - if e.errno != errno.ENOPROTOOPT: - raise - t = s.gettimeout() - if timeout: - s.settimeout(timeout) - log.debug("[#0000] C: %s", resolved_address) - s.connect(resolved_address) - s.settimeout(t) - keep_alive = 1 if keep_alive else 0 - s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) - return s - except SocketTimeout: - log.debug("[#0000] S: %s", resolved_address) - log.debug("[#0000] C: %s", resolved_address) - cls._kill_raw_socket(s) - raise ServiceUnavailable( - "Timed out trying to establish connection to " - f"{resolved_address!r}" - ) from None - except Exception as error: - log.debug( - "[#0000] S: %s %s", - type(error).__name__, - " ".join(map(repr, error.args)), - ) - log.debug("[#0000] C: %s", resolved_address) - cls._kill_raw_socket(s) - if isinstance(error, OSError): + if len(resolved_address) == 2: + s = socket(AF_INET) + elif len(resolved_address) == 4: + s = socket(AF_INET6) + else: + raise ValueError( + f"Unsupported address {resolved_address!r}" + ) + try: + s.setsockopt(IPPROTO_TCP, TCP_NODELAY, 1) + except OSError as e: + # option might not be supported on all platforms + if e.errno != errno.ENOPROTOOPT: + raise + t = s.gettimeout() + if timeout: + s.settimeout(timeout) + log.debug("[#0000] C: %s", resolved_address) + s.connect(resolved_address) + s.settimeout(t) + keep_alive = 1 if keep_alive else 0 + s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) + except SocketTimeout: + log.debug("[#0000] S: %s", resolved_address) raise ServiceUnavailable( - f"Failed to establish connection to {resolved_address!r} " - f"(reason {error})" - ) from error - raise - - @classmethod - def _secure(cls, s, host, ssl_context): - local_port = s.getsockname()[1] - # Secure the connection if an SSL context has been provided - if ssl_context: - log.debug("[#%04X] C: %s", local_port, host) - try: - sni_host = host if HAS_SNI and host else None - s = ssl_context.wrap_socket(s, server_hostname=sni_host) - except (OSError, SSLError, CertificateError) as cause: - cls._kill_raw_socket(s) - raise BoltSecurityError( - message="Failed to establish encrypted connection.", - address=(host, local_port), - ) from cause - # Check that the server provides a certificate - der_encoded_server_certificate = s.getpeercert(binary_form=True) - if der_encoded_server_certificate is None: - raise BoltProtocolError( - "When using an encrypted socket, the server should always " - "provide a certificate", - address=(host, local_port), + "Timed out trying to establish connection to " + f"{resolved_address!r}" + ) from None + except Exception as error: + log.debug( + "[#0000] S: %s %s", + type(error).__name__, + " ".join(map(repr, error.args)), ) - return s - return s - - @classmethod - def _handshake(cls, s, resolved_address, deadline): - """ - Perform BOLT handshake. - - :param s: Socket - :param resolved_address: - :param deadline: - - :returns: (socket, version, client_handshake, server_response_data) - """ - local_port = s.getsockname()[1] - - # TODO: Optimize logging code - handshake = cls.Bolt.get_handshake() - handshake = struct.unpack(">16B", handshake) - handshake = [handshake[i : i + 4] for i in range(0, len(handshake), 4)] - - supported_versions = [ - f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" - for vx in handshake - ] - - log.debug( - "[#%04X] C: 0x%08X", - local_port, - int.from_bytes(cls.Bolt.MAGIC_PREAMBLE, byteorder="big"), - ) - log.debug( - "[#%04X] C: %s %s %s %s", - local_port, - *supported_versions, - ) - - request = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake() + if isinstance(error, OSError): + raise ServiceUnavailable( + "Failed to establish connection to " + f"{resolved_address!r} (reason {error})" + ) from error + raise - # Handle the handshake response - original_timeout = s.gettimeout() - s.settimeout(deadline.to_timeout()) - try: - s.sendall(request) - response = s.recv(4) - except OSError as exc: - raise ServiceUnavailable( - f"Failed to read any data from server {resolved_address!r} " - f"after connected (deadline {deadline})" - ) from exc - finally: - s.settimeout(original_timeout) - data_size = len(response) - if data_size == 0: - # If no data is returned after a successful select - # response, the server has closed the connection - log.debug("[#%04X] S: ", local_port) - cls._kill_raw_socket(s) - raise ServiceUnavailable( - f"Connection to {resolved_address} closed without handshake " - "response" - ) - if data_size != 4: - # Some garbled data has been received - log.debug("[#%04X] S: @*#!", local_port) - cls._kill_raw_socket(s) - raise BoltProtocolError( - "Expected four byte Bolt handshake response from " - f"{resolved_address!r}, received {response!r} instead; " - "check for incorrect port number", - address=resolved_address, - ) - elif response == b"HTTP": - log.debug("[#%04X] S: ", local_port) - cls._kill_raw_socket(s) - raise ServiceUnavailable( - f"Cannot to connect to Bolt service on {resolved_address!r} " - "(looks like HTTP)" - ) - agreed_version = response[-1], response[-2] - log.debug( - "[#%04X] S: 0x%06X%02X", - local_port, - agreed_version[1], - agreed_version[0], - ) - return cls(s), agreed_version, handshake, response + local_port = s.getsockname()[1] + # Secure the connection if an SSL context has been provided + if ssl_context: + hostname = resolved_address._host_name or None + sni_host = hostname if HAS_SNI and hostname else None + log.debug("[#%04X] C: %s", local_port, hostname) + try: + s = ssl_context.wrap_socket(s, server_hostname=sni_host) + except (OSError, SSLError, CertificateError) as cause: + raise BoltSecurityError( + message="Failed to establish encrypted connection.", + address=(hostname, local_port), + ) from cause + # Check that the server provides a certificate + der_encoded_server_certificate = s.getpeercert( + binary_form=True + ) + if der_encoded_server_certificate is None: + raise BoltProtocolError( + "When using an encrypted socket, the server should" + "always provide a certificate", + address=(hostname, local_port), + ) + except Exception: + if s is not None: + log.debug("[#0000] C: %s", resolved_address) + cls._kill_raw_socket(s) + raise - @classmethod - def close_socket(cls, socket_): - if isinstance(socket_, BoltSocket): - socket_ = socket_._socket - cls._kill_raw_socket(socket_) + return cls(s) - @classmethod - def _kill_raw_socket(cls, socket_): - with suppress(OSError): - socket_.shutdown(SHUT_RDWR) - with suppress(OSError): - socket_.close() + @abc.abstractmethod + def _handshake(self, resolved_address, deadline): ... @classmethod + @abc.abstractmethod def connect( cls, address, @@ -721,66 +474,17 @@ def connect( custom_resolver, ssl_context, keep_alive, - ): - """ - Connect and perform a handshake. + ): ... - Return a valid Connection object, assuming a protocol version can be - agreed. - """ - errors = [] - # Establish a connection to the host and port specified - # Catches refused connections see: - # https://docs.python.org/2/library/errno.html + @classmethod + def close_socket(cls, socket_): + if isinstance(socket_, BoltSocketBase): + socket_ = socket_._socket + cls._kill_raw_socket(socket_) - resolved_addresses = NetworkUtil.resolve_address( - addressing.Address(address), resolver=custom_resolver - ) - for resolved_address in resolved_addresses: - deadline_timeout = deadline.to_timeout() - if ( - deadline_timeout is not None - and deadline_timeout <= tcp_timeout - ): - tcp_timeout = deadline_timeout - s = None - try: - s = BoltSocket._connect( - resolved_address, tcp_timeout, keep_alive - ) - s = BoltSocket._secure( - s, resolved_address._host_name, ssl_context - ) - return BoltSocket._handshake(s, resolved_address, deadline) - except (BoltError, DriverError, OSError) as error: - try: - local_port = s.getsockname()[1] - except (OSError, AttributeError): - local_port = 0 - err_str = error.__class__.__name__ - if str(error): - err_str += ": " + str(error) - log.debug( - "[#%04X] S: %s", local_port, err_str - ) - if s: - cls.close_socket(s) - errors.append(error) - except Exception: - if s: - cls.close_socket(s) - raise - if not errors: - resolved_address_strs = tuple(map(str, resolved_addresses)) - raise ServiceUnavailable( - f"Couldn't connect to {address} " - f"(resolved to {resolved_address_strs})" - ) - else: - resolved_address_strs = tuple(map(str, resolved_addresses)) - error_strs = "\n".join(map(str, errors)) - raise ServiceUnavailable( - f"Couldn't connect to {address} " - f"(resolved to {resolved_address_strs}):\n" - f"{error_strs}" - ) from errors[0] + @classmethod + def _kill_raw_socket(cls, socket_): + with suppress(OSError): + socket_.shutdown(SHUT_RDWR) + with suppress(OSError): + socket_.close() diff --git a/src/neo4j/_sync/io/__init__.py b/src/neo4j/_sync/io/__init__.py index 5a7ea8312..c6fc7c496 100644 --- a/src/neo4j/_sync/io/__init__.py +++ b/src/neo4j/_sync/io/__init__.py @@ -32,6 +32,12 @@ ] +# [bolt-version-bump] search tag when changing bolt version support +from . import ( # noqa - imports needed to register protocol handlers + _bolt3, + _bolt4, + _bolt5, +) from ._bolt import Bolt from ._common import ( check_supported_server_product, diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index bcd5a6baf..39b72d9c8 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -23,7 +23,6 @@ from logging import getLogger from time import monotonic -from ..._async_compat.network import BoltSocket from ..._async_compat.util import Util from ..._auth_management import to_auth_dict from ..._codec.hydration import ( @@ -51,6 +50,7 @@ SessionExpired, ) from ..config import PoolConfig +from ._bolt_socket import BoltSocket from ._common import ( CommitResponse, Inbox, @@ -59,6 +59,8 @@ if t.TYPE_CHECKING: + import typing_extensions as te + from ..._api import TelemetryAPI @@ -134,6 +136,8 @@ class Bolt: # results for it. most_recent_qid = None + SKIP_REGISTRATION = False + def __init__( self, unresolved_address, @@ -257,101 +261,36 @@ def assert_notification_filtering_support(self): f"{self.server_info.agent!r}" ) - # [bolt-version-bump] search tag when changing bolt version support - @classmethod - def protocol_handlers(cls, protocol_version=None): - """ - Return a dictionary of available Bolt protocol handlers. - - The handlers are keyed by version tuple. If an explicit protocol - version is provided, the dictionary will contain either zero or one - items, depending on whether that version is supported. If no protocol - version is provided, all available versions will be returned. - - :param protocol_version: tuple identifying a specific protocol - version (e.g. (3, 5)) or None - :returns: dictionary of version tuple to handler class for all - relevant and supported protocol versions - :raise TypeError: if protocol version is not passed in a tuple - """ - # Carry out Bolt subclass imports locally to avoid circular dependency - # issues. - from ._bolt3 import Bolt3 - from ._bolt4 import ( - Bolt4x1, - Bolt4x2, - Bolt4x3, - Bolt4x4, - ) - from ._bolt5 import ( - Bolt5x0, - Bolt5x1, - Bolt5x2, - Bolt5x3, - Bolt5x4, - Bolt5x5, - Bolt5x6, - Bolt5x7, - Bolt5x8, - ) - - handlers = { - Bolt3.PROTOCOL_VERSION: Bolt3, - # 4.0 unsupported because no space left in the handshake - Bolt4x1.PROTOCOL_VERSION: Bolt4x1, - Bolt4x2.PROTOCOL_VERSION: Bolt4x2, - Bolt4x3.PROTOCOL_VERSION: Bolt4x3, - Bolt4x4.PROTOCOL_VERSION: Bolt4x4, - Bolt5x0.PROTOCOL_VERSION: Bolt5x0, - Bolt5x1.PROTOCOL_VERSION: Bolt5x1, - Bolt5x2.PROTOCOL_VERSION: Bolt5x2, - Bolt5x3.PROTOCOL_VERSION: Bolt5x3, - Bolt5x4.PROTOCOL_VERSION: Bolt5x4, - Bolt5x5.PROTOCOL_VERSION: Bolt5x5, - Bolt5x6.PROTOCOL_VERSION: Bolt5x6, - Bolt5x7.PROTOCOL_VERSION: Bolt5x7, - Bolt5x8.PROTOCOL_VERSION: Bolt5x8, - } + protocol_handlers: t.ClassVar[dict[Version, type[Bolt]]] = {} + def __init_subclass__(cls: type[te.Self], **kwargs: t.Any) -> None: + if cls.SKIP_REGISTRATION: + super().__init_subclass__(**kwargs) + return + protocol_version = cls.PROTOCOL_VERSION if protocol_version is None: - return handlers - - if not isinstance(protocol_version, tuple): - raise TypeError("Protocol version must be specified as a tuple") - - if protocol_version in handlers: - return {protocol_version: handlers[protocol_version]} - - return {} - - @classmethod - def version_list(cls, versions): - """ - Return a list of supported protocol versions in order of preference. - - The number of protocol versions (or ranges) returned is limited to 4. - """ - # In fact, 4.3 is the fist version to support ranges. However, the - # range support got backported to 4.2. But even if the server is too - # old to have the backport, negotiating BOLT 4.1 is no problem as it's - # equivalent to 4.2 - first_with_range_support = Version(4, 2) - result = [] - for version in versions: - if ( - result - and version >= first_with_range_support - and result[-1][0] == version[0] - and result[-1][1][1] == version[1] + 1 - ): - # can use range to encompass this version - result[-1][1][1] = version[1] - continue - result.append(Version(version[0], [version[1], version[1]])) - if len(result) == 4: - break - return result + raise ValueError( + "Bolt subclasses must define PROTOCOL_VERSION" + ) + if not ( + isinstance(protocol_version, Version) + and len(protocol_version) == 2 + and all(isinstance(i, int) for i in protocol_version) + ): + raise TypeError( + "PROTOCOL_VERSION must be a 2-tuple of integers, not " + f"{protocol_version!r}" + ) + if protocol_version in Bolt.protocol_handlers: + cls_conflict = Bolt.protocol_handlers[protocol_version] + raise TypeError( + f"Multiple classes for the same protocol version " + f"{protocol_version}: {cls}, {cls_conflict}" + ) + cls.protocol_handlers[protocol_version] = cls + super().__init_subclass__(**kwargs) + # [bolt-version-bump] search tag when changing bolt version support @classmethod def get_handshake(cls): """ @@ -360,12 +299,9 @@ def get_handshake(cls): The length is 16 bytes as specified in the Bolt version negotiation. :returns: bytes """ - supported_versions = sorted( - cls.protocol_handlers().keys(), reverse=True + return ( + b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x00\x03" ) - offered_versions = cls.version_list(supported_versions) - versions_bytes = (v.to_bytes() for v in offered_versions) - return b"".join(versions_bytes).ljust(16, b"\x00") @classmethod def ping(cls, address, *, deadline=None, pool_config=None): @@ -400,7 +336,6 @@ def ping(cls, address, *, deadline=None, pool_config=None): BoltSocket.close_socket(s) return protocol_version - # [bolt-version-bump] search tag when changing bolt version support @classmethod def open( cls, @@ -441,71 +376,17 @@ def open( ) pool_config.protocol_version = protocol_version - - # Carry out Bolt subclass imports locally to avoid circular dependency - # issues. - - # avoid new lines after imports for better readability and conciseness - # fmt: off - if protocol_version == (5, 8): - from ._bolt5 import Bolt5x8 - bolt_cls = Bolt5x8 - elif protocol_version == (5, 7): - from ._bolt5 import Bolt5x7 - bolt_cls = Bolt5x7 - elif protocol_version == (5, 6): - from ._bolt5 import Bolt5x6 - bolt_cls = Bolt5x6 - elif protocol_version == (5, 5): - from ._bolt5 import Bolt5x5 - bolt_cls = Bolt5x5 - elif protocol_version == (5, 4): - from ._bolt5 import Bolt5x4 - bolt_cls = Bolt5x4 - elif protocol_version == (5, 3): - from ._bolt5 import Bolt5x3 - bolt_cls = Bolt5x3 - elif protocol_version == (5, 2): - from ._bolt5 import Bolt5x2 - bolt_cls = Bolt5x2 - elif protocol_version == (5, 1): - from ._bolt5 import Bolt5x1 - bolt_cls = Bolt5x1 - elif protocol_version == (5, 0): - from ._bolt5 import Bolt5x0 - bolt_cls = Bolt5x0 - elif protocol_version == (4, 4): - from ._bolt4 import Bolt4x4 - bolt_cls = Bolt4x4 - elif protocol_version == (4, 3): - from ._bolt4 import Bolt4x3 - bolt_cls = Bolt4x3 - elif protocol_version == (4, 2): - from ._bolt4 import Bolt4x2 - bolt_cls = Bolt4x2 - elif protocol_version == (4, 1): - from ._bolt4 import Bolt4x1 - bolt_cls = Bolt4x1 - # Implementation for 4.0 exists, but there was no space left in the - # handshake to offer this version to the server. Hence, the server - # should never request us to speak bolt 4.0. - # elif protocol_version == (4, 0): - # from ._bolt4 import AsyncBolt4x0 - # bolt_cls = AsyncBolt4x0 - elif protocol_version == (3, 0): - from ._bolt3 import Bolt3 - bolt_cls = Bolt3 - # fmt: on - else: + protocol_handlers = Bolt.protocol_handlers + bolt_cls = protocol_handlers.get(protocol_version) + if bolt_cls is None: log.debug("[#%04X] C: ", s.getsockname()[1]) BoltSocket.close_socket(s) - supported_versions = cls.protocol_handlers().keys() # TODO: 6.0 - raise public DriverError subclass instead raise BoltHandshakeError( "The neo4j server does not support communication with this " "driver. This driver has support for Bolt protocols " - f"{tuple(map(str, supported_versions))}.", + f"{tuple(map(str, Bolt.protocol_handlers.keys()))}.", address=address, request_data=handshake, response_data=data, diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 99c04185a..19c719240 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -74,6 +74,8 @@ class Bolt4x0(Bolt): supports_notification_filtering = False + SKIP_REGISTRATION = True + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( @@ -550,6 +552,8 @@ class Bolt4x2(Bolt4x1): PROTOCOL_VERSION = Version(4, 2) + SKIP_REGISTRATION = False + class Bolt4x3(Bolt4x2): """ diff --git a/src/neo4j/_sync/io/_bolt_socket.py b/src/neo4j/_sync/io/_bolt_socket.py new file mode 100644 index 000000000..d07756f70 --- /dev/null +++ b/src/neo4j/_sync/io/_bolt_socket.py @@ -0,0 +1,359 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import asyncio +import dataclasses +import logging +import struct +import typing as t +from contextlib import suppress + +from ... import addressing +from ..._async_compat.network import ( + BoltSocketBase, + NetworkUtil, +) +from ..._exceptions import ( + BoltError, + BoltProtocolError, +) +from ...exceptions import ( + DriverError, + ServiceUnavailable, +) + + +if t.TYPE_CHECKING: + from ..._deadline import Deadline + + +log = logging.getLogger("neo4j.io") + + +@dataclasses.dataclass +class HandshakeCtx: + ctx: str + deadline: Deadline + local_port: int + resolved_address: addressing.ResolvedAddress + full_response: bytearray = dataclasses.field(default_factory=bytearray) + + +@dataclasses.dataclass +class BytesPrinter: + bytes: bytes | bytearray + + def __str__(self): + return f"0x{self.bytes.hex().upper()}" + + +class BoltSocket(BoltSocketBase): + def _parse_handshake_response_v1(self, ctx, response): + agreed_version = response[-1], response[-2] + log.debug( + "[#%04X] S: 0x%06X%02X", + ctx.local_port, + agreed_version[1], + agreed_version[0], + ) + return agreed_version + + def _parse_handshake_response_v2(self, ctx, response): + ctx.ctx = "handshake v2 offerings count" + num_offerings = self._read_varint(ctx) + offerings = [] + for i in range(num_offerings): + ctx.ctx = f"handshake v2 offering {i}" + offering_response = self._handshake_read(ctx, 4) + offering = offering_response[-1:-4:-1] + offerings.append(offering) + ctx.ctx = "handshake v2 capabilities" + _capabilities_offer = self._read_varint(ctx) + + if log.getEffectiveLevel() >= logging.DEBUG: + log.debug( + "[#%04X] S: %s [%i] %s %s", + ctx.local_port, + BytesPrinter(response), + num_offerings, + " ".join( + f"0x{vx[2]:04X}{vx[1]:02X}{vx[0]:02X}" for vx in offerings + ), + BytesPrinter(self._encode_varint(_capabilities_offer)), + ) + + supported_versions = sorted(self.Bolt.protocol_handlers.keys()) + chosen_version = 0, 0 + for v in supported_versions: + for offer_major, offer_minor, offer_range in offerings: + offer_max = (offer_major, offer_minor) + offer_min = (offer_major, offer_minor - offer_range) + if offer_min <= v <= offer_max: + chosen_version = v + break + + ctx.ctx = "handshake v2 chosen version" + self._handshake_send( + ctx, bytes((0, 0, chosen_version[1], chosen_version[0])) + ) + chosen_capabilities = 0 + capabilities = self._encode_varint(chosen_capabilities) + ctx.ctx = "handshake v2 chosen capabilities" + log.debug( + "[#%04X] C: 0x%06X%02X %s", + ctx.local_port, + chosen_version[1], + chosen_version[0], + BytesPrinter(capabilities), + ) + self._handshake_send(ctx, b"\x00") + + return chosen_version + + def _read_varint(self, ctx): + next_byte = (self._handshake_read(ctx, 1))[0] + res = next_byte & 0x7F + i = 0 + while next_byte & 0x80: + i += 1 + next_byte = (self._handshake_read(ctx, 1))[0] + res += (next_byte & 0x7F) << (7 * i) + return res + + @staticmethod + def _encode_varint(n): + res = bytearray() + while n >= 0x80: + res.append(n & 0x7F | 0x80) + n >>= 7 + res.append(n) + return res + + def _handshake_read(self, ctx, n): + original_timeout = self.gettimeout() + self.settimeout(ctx.deadline.to_timeout()) + try: + response = self.recv(n) + ctx.full_response.extend(response) + except OSError as exc: + raise ServiceUnavailable( + f"Failed to read {ctx.ctx} from server " + f"{ctx.resolved_address!r} (deadline {ctx.deadline})" + ) from exc + finally: + self.settimeout(original_timeout) + data_size = len(response) + if data_size == 0: + # If no data is returned after a successful select + # response, the server has closed the connection + log.debug("[#%04X] S: ", ctx.local_port) + self.close() + raise ServiceUnavailable( + f"Connection to {ctx.resolved_address} closed with incomplete " + f"handshake response" + ) + if data_size != n: + # Some garbled data has been received + log.debug("[#%04X] S: @*#!", ctx.local_port) + self.close() + raise BoltProtocolError( + f"Expected {ctx.ctx} from {ctx.resolved_address!r}, received " + f"{response!r} instead (so far {ctx.full_response!r}); " + "check for incorrect port number", + address=ctx.resolved_address, + ) + + return response + + def _handshake_send(self, ctx, data): + original_timeout = self.gettimeout() + self.settimeout(ctx.deadline.to_timeout()) + try: + self.sendall(data) + except OSError as exc: + raise ServiceUnavailable( + f"Failed to write {ctx.ctx} to server " + f"{ctx.resolved_address!r} (deadline {ctx.deadline})" + ) from exc + finally: + self.settimeout(original_timeout) + + def _handshake(self, resolved_address, deadline): + """ + Perform BOLT handshake. + + :param resolved_address: + :param deadline: Deadline for handshake + + :returns: (version, client_handshake, server_response_data) + """ + local_port = self.getsockname()[1] + + if log.getEffectiveLevel() >= logging.DEBUG: + handshake = self.Bolt.get_handshake() + handshake = struct.unpack(">16B", handshake) + handshake = [ + handshake[i : i + 4] for i in range(0, len(handshake), 4) + ] + + supported_versions = [ + f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}" + for vx in handshake + ] + + log.debug( + "[#%04X] C: 0x%08X", + local_port, + int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"), + ) + log.debug( + "[#%04X] C: %s %s %s %s", + local_port, + *supported_versions, + ) + + request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() + + ctx = HandshakeCtx( + ctx="handshake opening", + deadline=deadline, + local_port=local_port, + resolved_address=resolved_address, + ) + + self._handshake_send(ctx, request) + + ctx.ctx = "four byte Bolt handshake response" + response = self._handshake_read(ctx, 4) + + if response == b"HTTP": + log.debug("[#%04X] C: (received b'HTTP')", local_port) + self.close() + raise ServiceUnavailable( + f"Cannot to connect to Bolt service on {resolved_address!r} " + "(looks like HTTP)" + ) + elif response[-1] == 0xFF: + # manifest style handshake + manifest_version = response[-2] + if manifest_version == 0x01: + agreed_version = self._parse_handshake_response_v2( + ctx, + response, + ) + else: + raise BoltProtocolError( + "Unsupported Bolt handshake manifest version " + f"{manifest_version} received from {resolved_address!r}.", + address=resolved_address, + ) + else: + agreed_version = self._parse_handshake_response_v1( + ctx, + response, + ) + + return agreed_version, handshake, response + + @classmethod + def connect( + cls, + address, + *, + tcp_timeout, + deadline, + custom_resolver, + ssl_context, + keep_alive, + ): + """ + Connect and perform a handshake. + + Return a valid Connection object, assuming a protocol version can be + agreed. + """ + errors = [] + failed_addresses = [] + # Establish a connection to the host and port specified + # Catches refused connections see: + # https://docs.python.org/2/library/errno.html + + resolved_addresses = NetworkUtil.resolve_address( + addressing.Address(address), resolver=custom_resolver + ) + for resolved_address in resolved_addresses: + deadline_timeout = deadline.to_timeout() + if ( + deadline_timeout is not None + and deadline_timeout <= tcp_timeout + ): + tcp_timeout = deadline_timeout + s = None + try: + s = cls._connect_secure( + resolved_address, tcp_timeout, keep_alive, ssl_context + ) + agreed_version, handshake, response = s._handshake( + resolved_address, deadline + ) + return s, agreed_version, handshake, response + except (BoltError, DriverError, OSError) as error: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError, TypeError): + local_port = 0 + err_str = error.__class__.__name__ + if str(error): + err_str += ": " + str(error) + log.debug( + "[#%04X] S: %s %s", + local_port, + resolved_address, + err_str, + ) + if s: + cls.close_socket(s) + errors.append(error) + failed_addresses.append(resolved_address) + except asyncio.CancelledError: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError, TypeError): + local_port = 0 + log.debug( + "[#%04X] C: %s", local_port, resolved_address + ) + if s: + with suppress(OSError): + s.kill() + raise + except Exception: + if s: + cls.close_socket(s) + raise + address_strs = tuple(map(str, failed_addresses)) + if not errors: + raise ServiceUnavailable( + f"Couldn't connect to {address} (resolved to {address_strs})" + ) + else: + error_strs = "\n".join(map(str, errors)) + raise ServiceUnavailable( + f"Couldn't connect to {address} (resolved to {address_strs}):" + f"\n{error_strs}" + ) from errors[0] diff --git a/src/neo4j/api.py b/src/neo4j/api.py index 3fdc09fce..ce9bcbe5d 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -415,7 +415,13 @@ def update(self, metadata: dict) -> None: # TODO: 6.0 - this class should not be public. # As far the user is concerned, protocol versions should simply be a # tuple[int, int]. -class Version(tuple): +if t.TYPE_CHECKING: + _version_base = t.Tuple[int, int] +else: + _version_base = tuple + + +class Version(_version_base): def __new__(cls, *v): return super().__new__(cls, v) diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 586616e7c..e960ad8fc 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -190,6 +190,7 @@ def new_driver(backend, data): ("maxTxRetryTimeMs", "max_transaction_retry_time"), ("connectionAcquisitionTimeoutMs", "connection_acquisition_timeout"), ("livenessCheckTimeoutMs", "liveness_check_timeout"), + ("maxConnectionLifetimeMs", "max_connection_lifetime"), ): if data.get(timeout_testkit) is not None: kwargs[timeout_driver] = data[timeout_testkit] / 1000 diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 0fe66cc5a..be37f132e 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -50,7 +50,8 @@ "Feature:Auth:Kerberos": true, "Feature:Auth:Managed": true, "Feature:Bolt:3.0": true, - "Feature:Bolt:4.1": true, + "Feature:Bolt:4.0": "Dropped support legacy protocol version 4.0", + "Feature:Bolt:4.1": "Dropped support legacy protocol version 4.1", "Feature:Bolt:4.2": true, "Feature:Bolt:4.3": true, "Feature:Bolt:4.4": true, @@ -63,6 +64,7 @@ "Feature:Bolt:5.6": true, "Feature:Bolt:5.7": true, "Feature:Bolt:5.8": true, + "Feature:Bolt:HandshakeManifestV1": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", diff --git a/tests/_async_compat/__init__.py b/tests/_async_compat/__init__.py index 8170965ac..d4a9e1c45 100644 --- a/tests/_async_compat/__init__.py +++ b/tests/_async_compat/__init__.py @@ -17,7 +17,9 @@ from functools import wraps as _wraps from .mark_decorator import ( + async_fixture, AsyncTestDecorators, + fixture, mark_async_test, mark_sync_test, TestDecorators, @@ -27,6 +29,8 @@ __all__ = [ "AsyncTestDecorators", "TestDecorators", + "async_fixture", + "fixture", "mark_async_test", "mark_sync_test", "wrap_async", diff --git a/tests/_async_compat/mark_decorator.py b/tests/_async_compat/mark_decorator.py index 69a762b64..d0c35b80d 100644 --- a/tests/_async_compat/mark_decorator.py +++ b/tests/_async_compat/mark_decorator.py @@ -15,15 +15,21 @@ import pytest +import pytest_asyncio mark_async_test = pytest.mark.asyncio +async_fixture = pytest_asyncio.fixture + def mark_sync_test(f): return f +fixture = pytest.fixture + + class AsyncTestDecorators: mark_async_only_test = mark_async_test diff --git a/tests/conftest.py b/tests/conftest.py index be756c8ad..b6cdd7f7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -202,3 +202,53 @@ def _(): def watcher(): with watch("neo4j", out=sys.stdout, colour=True): yield + + +# TODO: 6.0 - +# when support for Python 3.7 is dropped and pytest-asyncio is bumped +# check if this fixture is still needed +@pytest.fixture +def event_loop(): + # Overwriting the default event loop injected by pytest-asyncio + # because its implementation doesn't properly shut down the loop + # (e.g., it doesn't call `shutdown_asyncgens`) + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() + yield loop + try: + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + if sys.version_info >= (3, 9): + loop.run_until_complete(loop.shutdown_default_executor()) + finally: + loop.close() + + +def _cancel_all_tasks(loop): + # Copied from Python 3.13's asyncio package with minor modifications + # in exception wording and variable naming + + # Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, + # 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022 + # Python Software Foundation; + # All Rights Reserved + to_cancel = asyncio.all_tasks(loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during loop shutdown", + "exception": task.exception(), + "task": task, + } + ) diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 4469b045c..e8b9e8886 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -20,7 +20,7 @@ import neo4j.auth_management from neo4j._async.io import AsyncBolt -from neo4j._async_compat.network import AsyncBoltSocket +from neo4j._async.io._bolt_socket import AsyncBoltSocket from neo4j._exceptions import BoltHandshakeError from ...._async_compat import ( @@ -29,20 +29,17 @@ ) -# python -m pytest tests/unit/io/test_class_bolt.py -s -v - - # [bolt-version-bump] search tag when changing bolt version support def test_class_method_protocol_handlers(): # fmt: off expected_handlers = { (3, 0), - (4, 1), (4, 2), (4, 3), (4, 4), + (4, 2), (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on - protocol_handlers = AsyncBolt.protocol_handlers() + protocol_handlers = AsyncBolt.protocol_handlers assert len(protocol_handlers) == len(expected_handlers) assert protocol_handlers.keys() == expected_handlers @@ -57,7 +54,7 @@ def test_class_method_protocol_handlers(): ((2, 0), 0), ((3, 0), 1), ((4, 0), 0), - ((4, 1), 1), + ((4, 1), 0), ((4, 2), 1), ((4, 3), 1), ((4, 4), 1), @@ -77,15 +74,8 @@ def test_class_method_protocol_handlers(): def test_class_method_protocol_handlers_with_protocol_version( test_input, expected ): - protocol_handlers = AsyncBolt.protocol_handlers( - protocol_version=test_input - ) - assert len(protocol_handlers) == expected - - -def test_class_method_protocol_handlers_with_invalid_protocol_version(): - with pytest.raises(TypeError): - AsyncBolt.protocol_handlers(protocol_version=2) + protocol_handlers = AsyncBolt.protocol_handlers + assert (test_input in protocol_handlers) == expected # [bolt-version-bump] search tag when changing bolt version support @@ -93,7 +83,7 @@ def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( handshake - == b"\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x00\x03" ) @@ -120,6 +110,10 @@ async def test_cancel_hello_in_open(mocker, none_auth): bolt_mock.socket = socket_mock bolt_mock.hello.side_effect = asyncio.CancelledError() bolt_mock.local_port = 1234 + mocker.patch.dict( + AsyncBolt.protocol_handlers, + {(5, 0): bolt_cls_mock}, + ) with pytest.raises(asyncio.CancelledError): await AsyncBolt.open(address, auth_manager=none_auth) @@ -132,6 +126,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ("bolt_version", "bolt_cls_path"), ( ((3, 0), "neo4j._async.io._bolt3.AsyncBolt3"), + ((4, 0), "neo4j._async.io._bolt4.AsyncBolt4x0"), ((4, 1), "neo4j._async.io._bolt4.AsyncBolt4x1"), ((4, 2), "neo4j._async.io._bolt4.AsyncBolt4x2"), ((4, 3), "neo4j._async.io._bolt4.AsyncBolt4x3"), @@ -168,6 +163,10 @@ async def test_version_negotiation( bolt_cls_mock.return_value.local_port = 1234 bolt_mock = bolt_cls_mock.return_value bolt_mock.socket = socket_mock + mocker.patch.dict( + AsyncBolt.protocol_handlers, + {bolt_version: bolt_cls_mock}, + ) connection = await AsyncBolt.open(address, auth_manager=none_auth) @@ -180,9 +179,11 @@ async def test_version_negotiation( "bolt_version", ( (0, 0), + (1, 0), (2, 0), - (4, 0), (3, 1), + (4, 0), + (4, 1), (5, 9), (6, 0), ), @@ -190,8 +191,8 @@ async def test_version_negotiation( @mark_async_test async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( - "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7')" + "('3.0', '4.2', '4.3', '4.4', " + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" ) address = ("localhost", 7687) @@ -224,12 +225,6 @@ async def test_cancel_manager_in_open(mocker): ) socket_cls_mock.connect.return_value = (socket_mock, (5, 0), None, None) socket_mock.getpeername.return_value = address - bolt_cls_mock = mocker.patch( - "neo4j._async.io._bolt5.AsyncBolt5x0", autospec=True - ) - bolt_mock = bolt_cls_mock.return_value - bolt_mock.socket = socket_mock - bolt_mock.local_port = 1234 auth_manager = mocker.AsyncMock( spec=neo4j.auth_management.AsyncAuthManager @@ -252,12 +247,6 @@ async def test_fail_manager_in_open(mocker): ) socket_cls_mock.connect.return_value = (socket_mock, (5, 0), None, None) socket_mock.getpeername.return_value = address - bolt_cls_mock = mocker.patch( - "neo4j._async.io._bolt5.AsyncBolt5x0", autospec=True - ) - bolt_mock = bolt_cls_mock.return_value - bolt_mock.socket = socket_mock - bolt_mock.local_port = 1234 auth_manager = mocker.AsyncMock( spec=neo4j.auth_management.AsyncAuthManager diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index c6add31f4..47f6f813d 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -31,7 +31,10 @@ ServiceUnavailable, ) -from ...._async_compat import mark_async_test +from ...._async_compat import ( + async_fixture, + mark_async_test, +) class AsyncFakeBoltPool(AsyncIOPool): @@ -113,7 +116,7 @@ async def test_bolt_connection_ping_timeout(): assert protocol_version is None -@pytest.fixture +@async_fixture async def pool(async_fake_connection_generator): async with AsyncFakeBoltPool( async_fake_connection_generator, ("127.0.0.1", 7687) diff --git a/tests/unit/common/test_exceptions.py b/tests/unit/common/test_exceptions.py index dfb444a01..fc880bbb5 100644 --- a/tests/unit/common/test_exceptions.py +++ b/tests/unit/common/test_exceptions.py @@ -82,7 +82,7 @@ def test_bolt_handshake_error(): b"\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00" ) response = b"\x00\x00\x00\x00" - supported_versions = Bolt.protocol_handlers().keys() + supported_versions = Bolt.protocol_handlers.keys() with pytest.raises(BoltHandshakeError) as e: error = BoltHandshakeError( diff --git a/tests/unit/mixed/async_compat/test_network.py b/tests/unit/mixed/async_compat/test_network.py index ee4c290d9..c3ffee811 100644 --- a/tests/unit/mixed/async_compat/test_network.py +++ b/tests/unit/mixed/async_compat/test_network.py @@ -23,7 +23,7 @@ import freezegun import pytest -from neo4j._async_compat.network import AsyncBoltSocket +from neo4j._async.io._bolt_socket import AsyncBoltSocket from neo4j._exceptions import SocketDeadlineExceededError from ...._async_compat.mark_decorator import mark_async_test diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index c3d6bace5..25a22ecb4 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -19,9 +19,9 @@ import pytest import neo4j.auth_management -from neo4j._async_compat.network import BoltSocket from neo4j._exceptions import BoltHandshakeError from neo4j._sync.io import Bolt +from neo4j._sync.io._bolt_socket import BoltSocket from ...._async_compat import ( mark_sync_test, @@ -29,20 +29,17 @@ ) -# python -m pytest tests/unit/io/test_class_bolt.py -s -v - - # [bolt-version-bump] search tag when changing bolt version support def test_class_method_protocol_handlers(): # fmt: off expected_handlers = { (3, 0), - (4, 1), (4, 2), (4, 3), (4, 4), + (4, 2), (4, 3), (4, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on - protocol_handlers = Bolt.protocol_handlers() + protocol_handlers = Bolt.protocol_handlers assert len(protocol_handlers) == len(expected_handlers) assert protocol_handlers.keys() == expected_handlers @@ -57,7 +54,7 @@ def test_class_method_protocol_handlers(): ((2, 0), 0), ((3, 0), 1), ((4, 0), 0), - ((4, 1), 1), + ((4, 1), 0), ((4, 2), 1), ((4, 3), 1), ((4, 4), 1), @@ -77,15 +74,8 @@ def test_class_method_protocol_handlers(): def test_class_method_protocol_handlers_with_protocol_version( test_input, expected ): - protocol_handlers = Bolt.protocol_handlers( - protocol_version=test_input - ) - assert len(protocol_handlers) == expected - - -def test_class_method_protocol_handlers_with_invalid_protocol_version(): - with pytest.raises(TypeError): - Bolt.protocol_handlers(protocol_version=2) + protocol_handlers = Bolt.protocol_handlers + assert (test_input in protocol_handlers) == expected # [bolt-version-bump] search tag when changing bolt version support @@ -93,7 +83,7 @@ def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( handshake - == b"\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x00\x01\xff\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x00\x03" ) @@ -120,6 +110,10 @@ def test_cancel_hello_in_open(mocker, none_auth): bolt_mock.socket = socket_mock bolt_mock.hello.side_effect = asyncio.CancelledError() bolt_mock.local_port = 1234 + mocker.patch.dict( + Bolt.protocol_handlers, + {(5, 0): bolt_cls_mock}, + ) with pytest.raises(asyncio.CancelledError): Bolt.open(address, auth_manager=none_auth) @@ -132,6 +126,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ("bolt_version", "bolt_cls_path"), ( ((3, 0), "neo4j._sync.io._bolt3.Bolt3"), + ((4, 0), "neo4j._sync.io._bolt4.Bolt4x0"), ((4, 1), "neo4j._sync.io._bolt4.Bolt4x1"), ((4, 2), "neo4j._sync.io._bolt4.Bolt4x2"), ((4, 3), "neo4j._sync.io._bolt4.Bolt4x3"), @@ -168,6 +163,10 @@ def test_version_negotiation( bolt_cls_mock.return_value.local_port = 1234 bolt_mock = bolt_cls_mock.return_value bolt_mock.socket = socket_mock + mocker.patch.dict( + Bolt.protocol_handlers, + {bolt_version: bolt_cls_mock}, + ) connection = Bolt.open(address, auth_manager=none_auth) @@ -180,9 +179,11 @@ def test_version_negotiation( "bolt_version", ( (0, 0), + (1, 0), (2, 0), - (4, 0), (3, 1), + (4, 0), + (4, 1), (5, 9), (6, 0), ), @@ -190,8 +191,8 @@ def test_version_negotiation( @mark_sync_test def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( - "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7')" + "('3.0', '4.2', '4.3', '4.4', " + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" ) address = ("localhost", 7687) @@ -224,12 +225,6 @@ def test_cancel_manager_in_open(mocker): ) socket_cls_mock.connect.return_value = (socket_mock, (5, 0), None, None) socket_mock.getpeername.return_value = address - bolt_cls_mock = mocker.patch( - "neo4j._sync.io._bolt5.Bolt5x0", autospec=True - ) - bolt_mock = bolt_cls_mock.return_value - bolt_mock.socket = socket_mock - bolt_mock.local_port = 1234 auth_manager = mocker.MagicMock( spec=neo4j.auth_management.AuthManager @@ -252,12 +247,6 @@ def test_fail_manager_in_open(mocker): ) socket_cls_mock.connect.return_value = (socket_mock, (5, 0), None, None) socket_mock.getpeername.return_value = address - bolt_cls_mock = mocker.patch( - "neo4j._sync.io._bolt5.Bolt5x0", autospec=True - ) - bolt_mock = bolt_cls_mock.return_value - bolt_mock.socket = socket_mock - bolt_mock.local_port = 1234 auth_manager = mocker.MagicMock( spec=neo4j.auth_management.AuthManager diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 64b7d9b52..96acf37e4 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -31,7 +31,10 @@ ServiceUnavailable, ) -from ...._async_compat import mark_sync_test +from ...._async_compat import ( + fixture, + mark_sync_test, +) class FakeBoltPool(IOPool): @@ -113,7 +116,7 @@ def test_bolt_connection_ping_timeout(): assert protocol_version is None -@pytest.fixture +@fixture def pool(fake_connection_generator): with FakeBoltPool( fake_connection_generator, ("127.0.0.1", 7687) diff --git a/tox.ini b/tox.ini index c6ef5b5d6..c6f46f7c8 100644 --- a/tox.ini +++ b/tox.ini @@ -11,7 +11,7 @@ setenv = TEST_SUITE_NAME={envname} usedevelop = true warnargs = - py{37,38,39,310,311,312}: -W error + -W error -W ignore::pytest.PytestUnraisableExceptionWarning parallel_show_output = true commands = coverage erase