Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/neo4j/_async/io/_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def __init_subclass__(cls: type[te.Self], **kwargs: t.Any) -> None:

# [bolt-version-bump] search tag when changing bolt version support
@classmethod
def get_handshake(cls):
def get_handshake(cls) -> bytes:
"""
Return the supported Bolt versions as bytes.

Expand Down
78 changes: 49 additions & 29 deletions src/neo4j/_async/io/_bolt_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,15 @@


if t.TYPE_CHECKING:
from ssl import SSLContext

import typing_extensions as te

from ..._deadline import Deadline
from ...addressing import (
Address,
ResolvedAddress,
)


log = logging.getLogger("neo4j.io")
Expand All @@ -63,7 +71,11 @@ def __str__(self):


class AsyncBoltSocket(AsyncBoltSocketBase):
async def _parse_handshake_response_v1(self, ctx, response):
async def _parse_handshake_response_v1(
self,
ctx: HandshakeCtx,
response: bytes,
) -> tuple[int, int]:
agreed_version = response[-1], response[-2]
log.debug(
"[#%04X] S: <HANDSHAKE> 0x%06X%02X",
Expand All @@ -73,7 +85,11 @@ async def _parse_handshake_response_v1(self, ctx, response):
)
return agreed_version

async def _parse_handshake_response_v2(self, ctx, response):
async def _parse_handshake_response_v2(
self,
ctx: HandshakeCtx,
response: bytes,
) -> tuple[int, int]:
ctx.ctx = "handshake v2 offerings count"
num_offerings = await self._read_varint(ctx)
offerings = []
Expand All @@ -85,7 +101,7 @@ async def _parse_handshake_response_v2(self, ctx, response):
ctx.ctx = "handshake v2 capabilities"
_capabilities_offer = await self._read_varint(ctx)

if log.getEffectiveLevel() >= logging.DEBUG:
if log.getEffectiveLevel() <= logging.DEBUG:
log.debug(
"[#%04X] S: <HANDSHAKE> %s [%i] %s %s",
ctx.local_port,
Expand Down Expand Up @@ -125,7 +141,7 @@ async def _parse_handshake_response_v2(self, ctx, response):

return chosen_version

async def _read_varint(self, ctx):
async def _read_varint(self, ctx: HandshakeCtx) -> int:
next_byte = (await self._handshake_read(ctx, 1))[0]
res = next_byte & 0x7F
i = 0
Expand All @@ -136,15 +152,15 @@ async def _read_varint(self, ctx):
return res

@staticmethod
def _encode_varint(n):
def _encode_varint(n: int) -> bytearray:
res = bytearray()
while n >= 0x80:
res.append(n & 0x7F | 0x80)
n >>= 7
res.append(n)
return res

async def _handshake_read(self, ctx, n):
async def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes:
original_timeout = self.gettimeout()
self.settimeout(ctx.deadline.to_timeout())
try:
Expand Down Expand Up @@ -193,7 +209,11 @@ async def _handshake_send(self, ctx, data):
finally:
self.settimeout(original_timeout)

async def _handshake(self, resolved_address, deadline):
async def _handshake(
self,
resolved_address: ResolvedAddress,
deadline: Deadline,
) -> tuple[tuple[int, int], bytes, bytes]:
"""
Perform BOLT handshake.

Expand All @@ -204,16 +224,16 @@ async def _handshake(self, resolved_address, deadline):
"""
local_port = self.getsockname()[1]

if log.getEffectiveLevel() >= logging.DEBUG:
handshake = self.Bolt.get_handshake()
handshake = struct.unpack(">16B", handshake)
handshake = [
handshake[i : i + 4] for i in range(0, len(handshake), 4)
handshake = self.Bolt.get_handshake()
if log.getEffectiveLevel() <= logging.DEBUG:
handshake_bytes: t.Sequence = struct.unpack(">16B", handshake)
handshake_bytes = [
handshake[i : i + 4] for i in range(0, len(handshake_bytes), 4)
]

supported_versions = [
f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}"
for vx in handshake
for vx in handshake_bytes
]

log.debug(
Expand All @@ -227,7 +247,7 @@ async def _handshake(self, resolved_address, deadline):
*supported_versions,
)

request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake()
request = self.Bolt.MAGIC_PREAMBLE + handshake

ctx = HandshakeCtx(
ctx="handshake opening",
Expand Down Expand Up @@ -273,14 +293,14 @@ async def _handshake(self, resolved_address, deadline):
@classmethod
async def connect(
cls,
address,
address: Address,
*,
tcp_timeout,
deadline,
custom_resolver,
ssl_context,
keep_alive,
):
tcp_timeout: float | None,
deadline: Deadline,
custom_resolver: t.Callable | None,
ssl_context: SSLContext | None,
keep_alive: bool,
) -> tuple[te.Self, tuple[int, int], bytes, bytes]:
"""
Connect and perform a handshake.

Expand Down Expand Up @@ -313,10 +333,10 @@ async def connect(
)
return s, agreed_version, handshake, response
except (BoltError, DriverError, OSError) as error:
try:
local_port = s.getsockname()[1]
except (OSError, AttributeError, TypeError):
local_port = 0
local_port = 0
if isinstance(s, cls):
with suppress(OSError, AttributeError, TypeError):
local_port = s.getsockname()[1]
err_str = error.__class__.__name__
if str(error):
err_str += ": " + str(error)
Expand All @@ -331,10 +351,10 @@ async def connect(
errors.append(error)
failed_addresses.append(resolved_address)
except asyncio.CancelledError:
try:
local_port = s.getsockname()[1]
except (OSError, AttributeError, TypeError):
local_port = 0
local_port = 0
if isinstance(s, cls):
with suppress(OSError, AttributeError, TypeError):
local_port = s.getsockname()[1]
log.debug(
"[#%04X] C: <CANCELED> %s", local_port, resolved_address
)
Expand Down
6 changes: 3 additions & 3 deletions src/neo4j/_async_compat/network/_bolt_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _sanitize_deadline(deadline):
class AsyncBoltSocketBase(abc.ABC):
Bolt: te.Final[type[AsyncBolt]] = None # type: ignore[assignment]

def __init__(self, reader, protocol, writer):
def __init__(self, reader, protocol, writer) -> None:
self._reader = reader # type: asyncio.StreamReader
self._protocol = protocol # type: asyncio.StreamReaderProtocol
self._writer = writer # type: asyncio.StreamWriter
Expand Down Expand Up @@ -171,7 +171,7 @@ def kill(self):
@classmethod
async def _connect_secure(
cls, resolved_address, timeout, keep_alive, ssl_context
):
) -> te.Self:
"""
Connect to the address and return the socket.

Expand Down Expand Up @@ -202,7 +202,7 @@ async def _connect_secure(
keep_alive = 1 if keep_alive else 0
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)

ssl_kwargs = {}
ssl_kwargs: dict[str, t.Any] = {}

if ssl_context is not None:
hostname = resolved_address._host_name or None
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j/_sync/io/_bolt.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

78 changes: 49 additions & 29 deletions src/neo4j/_sync/io/_bolt_socket.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading