diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f379bf89..973aec661 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -190,6 +190,9 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. - On failed liveness check (s. `liveness_check_timeout` configuration option), the driver will no longer remove the remote from the cached routing tables, but only close the connection under test. This aligns the driver with the other official Neo4j drivers. +- The driver incorrectly applied a timeout hint received from the server to both read and write I/O operations. + It is now only applied to read I/O operations. + In turn, a new configuration option `connection_write_timeout` with a default value of `30 seconds` is introduced. ## Version 5.28 diff --git a/docs/source/api.rst b/docs/source/api.rst index b9c77fd34..67d8ecf82 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -399,6 +399,7 @@ Additional configuration can be provided via the :class:`neo4j.Driver` construct + :ref:`connection-acquisition-timeout-ref` + :ref:`connection-timeout-ref` ++ :ref:`connection-write-timeout-ref` + :ref:`encrypted-ref` + :ref:`keep-alive-ref` + :ref:`max-connection-lifetime-ref` @@ -430,7 +431,7 @@ it should be chosen larger than :ref:`connection-timeout-ref`. :Type: ``float`` :Default: ``60.0`` -.. versionadded:: 6.0 +.. versionchanged:: 6.0 The setting now entails *anything* required to acquire a connection. This includes potential fetching of routing tables which in itself requires acquiring a connection. Previously, the timeout would be restarted for such auxiliary connection acquisitions. @@ -450,6 +451,16 @@ connection can be used to perform database related work. :Default: ``30.0`` +.. _connection-write-timeout-ref: + +``connection_write_timeout`` +---------------------------- +The maximum amount of time in seconds to wait for TCP write operations to complete. + +:Type: ``float`` +:Default: ``30.0`` + + .. _encrypted-ref: ``encrypted`` diff --git a/src/neo4j/_async/config.py b/src/neo4j/_async/config.py index 5468d20d8..d738c6f6f 100644 --- a/src/neo4j/_async/config.py +++ b/src/neo4j/_async/config.py @@ -54,6 +54,10 @@ class AsyncPoolConfig(Config): # The maximum amount of time to wait for a TCP connection to be # established. + #: Connection Write Timeout + connection_write_timeout = 30.0 # seconds + # The maximum amount of time to wait for I/O write operations to complete. + #: Custom Resolver resolver = None # Custom resolver function, returning list of resolved addresses. diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index ea82d37bf..4ac2ab8a2 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -127,6 +127,7 @@ def driver( liveness_check_timeout: float | None = ..., max_connection_pool_size: int = ..., connection_timeout: float = ..., + connection_write_timeout: float = ..., resolver: ( t.Callable[[Address], t.Iterable[Address]] | t.Callable[[Address], t.Awaitable[t.Iterable[Address]]] diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index f6f0ba454..3c2d9e73b 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -418,11 +418,13 @@ async def open( ) try: - connection.socket.set_deadline(deadline) + connection.socket.set_read_deadline(deadline) + connection.socket.set_write_deadline(deadline) try: await connection.hello() finally: - connection.socket.set_deadline(None) + connection.socket.set_read_deadline(None) + connection.socket.set_write_deadline(None) except ( Exception, # Python 3.8+: CancelledError is a subclass of BaseException diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index beeaeef04..4072d6a3b 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -626,7 +626,7 @@ def on_success(metadata): "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) + self.socket.set_read_timeout(recv_timeout) else: log.info( "[#%04X] _: Server supplied an " diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 0648e5a0e..d24a59a9c 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -149,7 +149,7 @@ def on_success(metadata): "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) + self.socket.set_read_timeout(recv_timeout) else: log.info( "[#%04X] _: Server supplied an " @@ -622,7 +622,7 @@ def on_success(metadata): "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) + self.socket.set_read_timeout(recv_timeout) else: log.info( "[#%04X] _: Server supplied an " @@ -708,7 +708,7 @@ def on_success(metadata): "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) + self.socket.set_read_timeout(recv_timeout) else: log.info( "[#%04X] _: Server supplied an " diff --git a/src/neo4j/_async/io/_bolt6.py b/src/neo4j/_async/io/_bolt6.py index 311c954e7..06c77fc40 100644 --- a/src/neo4j/_async/io/_bolt6.py +++ b/src/neo4j/_async/io/_bolt6.py @@ -158,7 +158,7 @@ def on_success(metadata): "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) + self.socket.set_read_timeout(recv_timeout) else: log.info( "[#%04X] _: Server supplied an " diff --git a/src/neo4j/_async/io/_bolt_socket.py b/src/neo4j/_async/io/_bolt_socket.py index dfac4a32b..c33062e9d 100644 --- a/src/neo4j/_async/io/_bolt_socket.py +++ b/src/neo4j/_async/io/_bolt_socket.py @@ -32,7 +32,10 @@ BoltError, BoltProtocolError, ) -from ..._io import BoltProtocolVersion +from ..._io import ( + BoltProtocolVersion, + min_timeout, +) from ...exceptions import ( DriverError, ServiceUnavailable, @@ -157,8 +160,8 @@ def _encode_varint(n: int) -> bytearray: return res async def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes: - original_timeout = self.gettimeout() - self.settimeout(ctx.deadline.to_timeout()) + original_timeout = self.get_read_timeout() + self.set_read_timeout(ctx.deadline.to_timeout()) try: response = await self.recv(n) ctx.full_response.extend(response) @@ -168,7 +171,7 @@ async def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes: f"{ctx.resolved_address!r} (deadline {ctx.deadline})" ) from exc finally: - self.settimeout(original_timeout) + self.set_read_timeout(original_timeout) data_size = len(response) if data_size == 0: # If no data is returned after a successful select @@ -192,9 +195,11 @@ async def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes: return response - async def _handshake_send(self, ctx, data): - original_timeout = self.gettimeout() - self.settimeout(ctx.deadline.to_timeout()) + async def _handshake_send(self, ctx, data, write_timeout=None): + original_timeout = self.get_write_timeout() + self.set_write_timeout( + min_timeout(ctx.deadline.to_timeout(), write_timeout) + ) try: await self.sendall(data) except OSError as exc: @@ -203,7 +208,7 @@ async def _handshake_send(self, ctx, data): f"{ctx.resolved_address!r} (deadline {ctx.deadline})" ) from exc finally: - self.settimeout(original_timeout) + self.set_write_timeout(original_timeout) async def _handshake( self, diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py index 7ed82e528..c93781e45 100644 --- a/src/neo4j/_async_compat/network/_bolt_socket.py +++ b/src/neo4j/_async_compat/network/_bolt_socket.py @@ -72,6 +72,14 @@ def _sanitize_deadline(deadline): return deadline +def _sanitize_timeout(timeout): + if timeout is None: + return timeout + else: + assert timeout >= 0 + return timeout + + class AsyncBoltSocketBase(abc.ABC): Bolt: t.Final[type[AsyncBolt]] = None # type: ignore[assignment] @@ -79,19 +87,42 @@ def __init__(self, reader, protocol, writer) -> None: self._reader = reader # type: asyncio.StreamReader self._protocol = protocol # type: asyncio.StreamReaderProtocol self._writer = writer # type: asyncio.StreamWriter + self._read_deadline = None + self._write_deadline = None # 0 - non-blocking # None infinitely blocking # int - seconds to wait for data - self._timeout = None - self._deadline = None + self._read_timeout = None + self._write_timeout = None + + def _wait_for_read(self, io_async_fn, *args, **kwargs): + return self._wait_for_io( + "read", + self._read_timeout, + self._read_deadline, + io_async_fn, + *args, + **kwargs, + ) + + def _wait_for_write(self, io_async_fn, *args, **kwargs): + return self._wait_for_io( + "write", + self._write_timeout, + self._write_deadline, + io_async_fn, + *args, + **kwargs, + ) - async def _wait_for_io(self, io_async_fn, *args, **kwargs): - timeout = self._timeout + async def _wait_for_io( + self, name, timeout, deadline, io_async_fn, *args, **kwargs + ): to_raise = TimeoutError - if self._deadline is not None: - deadline_timeout = self._deadline.to_timeout() + if deadline is not None: + deadline_timeout = deadline.to_timeout() if deadline_timeout <= 0: - raise SocketDeadlineExceededError("timed out") + raise SocketDeadlineExceededError(f"{name} timed out") if timeout is None or deadline_timeout <= timeout: timeout = deadline_timeout to_raise = SocketDeadlineExceededError @@ -111,13 +142,31 @@ async def _wait_for_io(self, io_async_fn, *args, **kwargs): try: return await wait_for(io_fut, timeout) except asyncio.TimeoutError as e: - raise to_raise("timed out") from e + raise to_raise(f"{name} timed out") from e + + def get_read_deadline(self): + return self._read_deadline + + def set_read_deadline(self, deadline): + self._read_deadline = _sanitize_deadline(deadline) + + def get_write_deadline(self): + return self._write_deadline + + def set_write_deadline(self, deadline): + self._write_deadline = _sanitize_deadline(deadline) - def get_deadline(self): - return self._deadline + def get_read_timeout(self): + return self._read_timeout - def set_deadline(self, deadline): - self._deadline = _sanitize_deadline(deadline) + def set_read_timeout(self, timeout): + self._read_timeout = _sanitize_timeout(timeout) + + def get_write_timeout(self): + return self._write_timeout + + def set_write_timeout(self, timeout): + self._write_timeout = _sanitize_timeout(timeout) @property def _socket(self) -> socket: @@ -134,28 +183,18 @@ def getpeercert(self, *args, **kwargs): *args, **kwargs ) - def gettimeout(self): - return self._timeout - - def settimeout(self, timeout): - if timeout is None: - self._timeout = timeout - else: - assert timeout >= 0 - self._timeout = timeout - async def recv(self, n): - return await self._wait_for_io(self._reader.read, n) + return await self._wait_for_read(self._reader.read, n) async def recv_into(self, buffer, nbytes): # FIXME: not particularly memory or time efficient - res = await self._wait_for_io(self._reader.read, nbytes) + res = await self._wait_for_read(self._reader.read, nbytes) buffer[: len(res)] = res return len(res) async def sendall(self, data): self._writer.write(data) - return await self._wait_for_io(self._writer.drain) + return await self._wait_for_write(self._writer.drain) async def close(self): self._writer.close() @@ -247,6 +286,12 @@ async def _connect_secure( cls._kill_raw_socket(s) raise except (SSLError, CertificateError) as error: + log.debug( + "[#0000] S: %s: %s", + resolved_address, + error, + ) + log.debug("[#0000] C: %s", resolved_address) if s: cls._kill_raw_socket(s) raise BoltSecurityError( @@ -310,7 +355,10 @@ class BoltSocketBase: def __init__(self, socket_: socket): self._socket = socket_ - self._deadline = None + self._read_deadline = None + self._write_deadline = None + self._read_timeout = None + self._write_timeout = None @property def _socket(self): @@ -325,46 +373,87 @@ def _socket(self, socket_: socket | SSLSocket): self.getpeercert = t.cast(SSLSocket, socket_).getpeercert elif "getpeercert" in self.__dict__: del self.__dict__["getpeercert"] - self.gettimeout = socket_.gettimeout - self.settimeout = socket_.settimeout getsockname: t.Callable = None # type: ignore getpeername: t.Callable = None # type: ignore getpeercert: t.Callable = None # type: ignore - gettimeout: t.Callable = None # type: ignore - settimeout: t.Callable = None # type: ignore - def _wait_for_io(self, func, *args, **kwargs): - if self._deadline is None: + def _wait_for_read(self, func, *args, **kwargs): + return self._wait_for_io( + "read", + self._read_timeout, + self._read_deadline, + func, + *args, + **kwargs, + ) + + def _wait_for_write(self, func, *args, **kwargs): + return self._wait_for_io( + "write", + self._write_timeout, + self._write_deadline, + func, + *args, + **kwargs, + ) + + def _wait_for_io(self, name, timeout, deadline, func, *args, **kwargs): + if deadline is None: + deadline_timeout = None + else: + deadline_timeout = deadline.to_timeout() + if deadline_timeout <= 0: + raise SocketDeadlineExceededError(f"{name} timed out") + if deadline_timeout is not None and ( + timeout is None or deadline_timeout <= timeout + ): + effective_timeout = deadline_timeout + rewrite_error = True + else: + effective_timeout = timeout + rewrite_error = False + + self._socket.settimeout(effective_timeout) + try: return func(*args, **kwargs) - timeout = self._socket.gettimeout() - deadline_timeout = self._deadline.to_timeout() - if deadline_timeout <= 0: - raise SocketDeadlineExceededError("timed out") - if timeout is None or deadline_timeout <= timeout: - self._socket.settimeout(deadline_timeout) - try: - return func(*args, **kwargs) - except TimeoutError as e: - raise SocketDeadlineExceededError("timed out") from e - finally: - self._socket.settimeout(timeout) - return func(*args, **kwargs) + except TimeoutError as e: + if not rewrite_error: + raise + raise SocketDeadlineExceededError(f"{name} timed out") from e + + def get_read_deadline(self): + return self._read_deadline - def get_deadline(self): - return self._deadline + def set_read_deadline(self, deadline): + self._read_deadline = _sanitize_deadline(deadline) - def set_deadline(self, deadline): - self._deadline = _sanitize_deadline(deadline) + def get_write_deadline(self): + return self._write_deadline + + def set_write_deadline(self, deadline): + self._write_deadline = _sanitize_deadline(deadline) + + def get_read_timeout(self): + return self._read_timeout + + def set_read_timeout(self, timeout): + self._read_timeout = _sanitize_timeout(timeout) + + def get_write_timeout(self): + return self._write_timeout + + def set_write_timeout(self, timeout): + self._write_timeout = _sanitize_timeout(timeout) def recv(self, n): - return self._wait_for_io(self._socket.recv, n) + return self._wait_for_read(self._socket.recv, n) def recv_into(self, buffer, nbytes): - return self._wait_for_io(self._socket.recv_into, buffer, nbytes) + return self._wait_for_read(self._socket.recv_into, buffer, nbytes) def sendall(self, data): - return self._wait_for_io(self._socket.sendall, data) + return self._wait_for_write(self._socket.sendall, data) def close(self): self.close_socket(self._socket) @@ -412,6 +501,9 @@ def _connect_secure( s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) except TimeoutError: log.debug("[#0000] S: %s", resolved_address) + log.debug("[#0000] C: %s", resolved_address) + if s: + cls._kill_raw_socket(s) raise ServiceUnavailable( "Timed out trying to establish connection to " f"{resolved_address!r}" @@ -422,6 +514,9 @@ def _connect_secure( type(error).__name__, " ".join(map(repr, error.args)), ) + log.debug("[#0000] C: %s", resolved_address) + if s: + cls._kill_raw_socket(s) if isinstance(error, OSError): raise ServiceUnavailable( "Failed to establish connection to " @@ -438,6 +533,14 @@ def _connect_secure( try: s = ssl_context.wrap_socket(s, server_hostname=sni_host) except (OSError, SSLError, CertificateError) as cause: + log.debug( + "[#0000] S: %s: %s", + resolved_address, + cause, + ) + log.debug("[#0000] C: %s", resolved_address) + if s: + cls._kill_raw_socket(s) raise BoltSecurityError( message="Failed to establish encrypted connection.", address=(hostname, local_port), @@ -447,6 +550,13 @@ def _connect_secure( binary_form=True ) if der_encoded_server_certificate is None: + log.debug( + "[#0000] S: %s: no certificate", + resolved_address, + ) + log.debug("[#0000] C: %s", resolved_address) + if s: + cls._kill_raw_socket(s) raise BoltProtocolError( "When using an encrypted socket, the server should" "always provide a certificate", diff --git a/src/neo4j/_deadline.py b/src/neo4j/_deadline.py index 273d7a2eb..af195bdeb 100644 --- a/src/neo4j/_deadline.py +++ b/src/neo4j/_deadline.py @@ -84,16 +84,41 @@ def merge_deadlines_and_timeouts(*deadline): @contextmanager def connection_deadline(connection, deadline): - original_deadline = connection.socket.get_deadline() - if deadline is None and original_deadline is not None: + with ( + connection_read_deadline(connection, deadline), + connection_write_deadline(connection, deadline), + ): + yield + + +def connection_read_deadline(connection, deadline): + return _connection_deadline_wrapper( + deadline, + connection.socket.get_read_deadline, + connection.socket.set_read_deadline, + ) + + +def connection_write_deadline(connection, deadline): + return _connection_deadline_wrapper( + deadline, + connection.socket.get_write_deadline, + connection.socket.set_write_deadline, + ) + + +@contextmanager +def _connection_deadline_wrapper(deadline, deadline_getter, deadline_setter): + if deadline is None: # nothing to do here yield return + original_deadline = deadline_getter() deadline = merge_deadlines( d for d in (deadline, original_deadline) if d is not None ) - connection.socket.set_deadline(deadline) + deadline_setter(deadline) try: yield finally: - connection.socket.set_deadline(original_deadline) + deadline_setter(original_deadline) diff --git a/src/neo4j/_io/__init__.py b/src/neo4j/_io/__init__.py index d6fd674bd..dcd677bc0 100644 --- a/src/neo4j/_io/__init__.py +++ b/src/neo4j/_io/__init__.py @@ -16,6 +16,8 @@ from __future__ import annotations +import math + from .. import _typing as t # noqa: TC001 @@ -94,3 +96,15 @@ def __repr__(self) -> str: def __str__(self) -> str: return f"{self.major}.{self.minor}" + + +def min_timeout(*timeouts: float | None) -> float | None: + """Return the minimum timeout from an iterable of timeouts.""" + return min( + ( + to + for to in timeouts + if to is not None and not math.isnan(to) and to >= 0 + ), + default=None, + ) diff --git a/src/neo4j/_sync/config.py b/src/neo4j/_sync/config.py index e528ab074..48951a7db 100644 --- a/src/neo4j/_sync/config.py +++ b/src/neo4j/_sync/config.py @@ -54,6 +54,10 @@ class PoolConfig(Config): # The maximum amount of time to wait for a TCP connection to be # established. + #: Connection Write Timeout + connection_write_timeout = 30.0 # seconds + # The maximum amount of time to wait for I/O write operations to complete. + #: Custom Resolver resolver = None # Custom resolver function, returning list of resolved addresses. diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 36d7fb3c3..f0698d4dc 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -126,6 +126,7 @@ def driver( liveness_check_timeout: float | None = ..., max_connection_pool_size: int = ..., connection_timeout: float = ..., + connection_write_timeout: float = ..., resolver: ( t.Callable[[Address], t.Iterable[Address]] | t.Callable[[Address], t.Union[t.Iterable[Address]]] diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 55f9fbd08..134c09248 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -418,11 +418,13 @@ def open( ) try: - connection.socket.set_deadline(deadline) + connection.socket.set_read_deadline(deadline) + connection.socket.set_write_deadline(deadline) try: connection.hello() finally: - connection.socket.set_deadline(None) + connection.socket.set_read_deadline(None) + connection.socket.set_write_deadline(None) except ( Exception, # Python 3.8+: CancelledError is a subclass of BaseException diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 87213df18..543e3eebf 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -626,7 +626,7 @@ def on_success(metadata): "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) + self.socket.set_read_timeout(recv_timeout) else: log.info( "[#%04X] _: Server supplied an " diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 916c789e6..2468e754b 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -149,7 +149,7 @@ def on_success(metadata): "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) + self.socket.set_read_timeout(recv_timeout) else: log.info( "[#%04X] _: Server supplied an " @@ -622,7 +622,7 @@ def on_success(metadata): "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) + self.socket.set_read_timeout(recv_timeout) else: log.info( "[#%04X] _: Server supplied an " @@ -708,7 +708,7 @@ def on_success(metadata): "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) + self.socket.set_read_timeout(recv_timeout) else: log.info( "[#%04X] _: Server supplied an " diff --git a/src/neo4j/_sync/io/_bolt6.py b/src/neo4j/_sync/io/_bolt6.py index dc7c213bf..4d89c6666 100644 --- a/src/neo4j/_sync/io/_bolt6.py +++ b/src/neo4j/_sync/io/_bolt6.py @@ -158,7 +158,7 @@ def on_success(metadata): "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) + self.socket.set_read_timeout(recv_timeout) else: log.info( "[#%04X] _: Server supplied an " diff --git a/src/neo4j/_sync/io/_bolt_socket.py b/src/neo4j/_sync/io/_bolt_socket.py index 00ea9fa8a..44b18edc6 100644 --- a/src/neo4j/_sync/io/_bolt_socket.py +++ b/src/neo4j/_sync/io/_bolt_socket.py @@ -32,7 +32,10 @@ BoltError, BoltProtocolError, ) -from ..._io import BoltProtocolVersion +from ..._io import ( + BoltProtocolVersion, + min_timeout, +) from ...exceptions import ( DriverError, ServiceUnavailable, @@ -157,8 +160,8 @@ def _encode_varint(n: int) -> bytearray: return res def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes: - original_timeout = self.gettimeout() - self.settimeout(ctx.deadline.to_timeout()) + original_timeout = self.get_read_timeout() + self.set_read_timeout(ctx.deadline.to_timeout()) try: response = self.recv(n) ctx.full_response.extend(response) @@ -168,7 +171,7 @@ def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes: f"{ctx.resolved_address!r} (deadline {ctx.deadline})" ) from exc finally: - self.settimeout(original_timeout) + self.set_read_timeout(original_timeout) data_size = len(response) if data_size == 0: # If no data is returned after a successful select @@ -192,9 +195,11 @@ def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes: return response - def _handshake_send(self, ctx, data): - original_timeout = self.gettimeout() - self.settimeout(ctx.deadline.to_timeout()) + def _handshake_send(self, ctx, data, write_timeout=None): + original_timeout = self.get_write_timeout() + self.set_write_timeout( + min_timeout(ctx.deadline.to_timeout(), write_timeout) + ) try: self.sendall(data) except OSError as exc: @@ -203,7 +208,7 @@ def _handshake_send(self, ctx, data): f"{ctx.resolved_address!r} (deadline {ctx.deadline})" ) from exc finally: - self.settimeout(original_timeout) + self.set_write_timeout(original_timeout) def _handshake( self, diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index c5d3b29c3..dc5186a8e 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -67,17 +67,23 @@ def close_side_effect(): mock.AsyncMock(side_effect=close_side_effect), "close" ) - self.socket.attach_mock( - mock.Mock(return_value=None), "get_deadline" - ) + for op in ("read", "write"): + self.socket.attach_mock( + mock.Mock(return_value=None), f"get_{op}_deadline" + ) - def set_deadline_side_effect(deadline): - deadline = Deadline.from_timeout_or_deadline(deadline) - self.socket.get_deadline.return_value = deadline + def make_set_deadline_side_effect(op_): + def side_effect(deadline): + deadline = Deadline.from_timeout_or_deadline(deadline) + get_mock = getattr(self.socket, f"get_{op_}_deadline") + get_mock.return_value = deadline - self.socket.attach_mock( - mock.Mock(side_effect=set_deadline_side_effect), "set_deadline" - ) + return side_effect + + self.socket.attach_mock( + mock.Mock(side_effect=make_set_deadline_side_effect(op)), + f"set_{op}_deadline", + ) @property def is_reset(self): diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index 509b6c76e..6977e2434 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -146,7 +146,8 @@ async def test_hint_recv_timeout_seconds_gets_ignored( sockets = fake_socket_pair( address, AsyncBolt3.PACKER_CLS, AsyncBolt3.UNPACKER_CLS ) - sockets.client.settimeout = mocker.AsyncMock() + sockets.client.set_read_timeout = mocker.AsyncMock() + sockets.client.set_write_timeout = mocker.AsyncMock() await sockets.server.send_message( b"\x70", { @@ -158,7 +159,8 @@ async def test_hint_recv_timeout_seconds_gets_ignored( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) await connection.hello() - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() + sockets.client.set_write_timeout.assert_not_called() CREDENTIALS = "+++super-secret-sauce+++" @@ -185,7 +187,6 @@ async def test_credentials_are_not_logged( packer_cls=AsyncBolt3.PACKER_CLS, unpacker_cls=AsyncBolt3.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt3( address, diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index e3af55ad0..ae37da337 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -251,7 +251,8 @@ async def test_hint_recv_timeout_seconds_gets_ignored( packer_cls=AsyncBolt4x0.PACKER_CLS, unpacker_cls=AsyncBolt4x0.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.MagicMock() + sockets.client.set_read_timeout = mocker.MagicMock() + sockets.client.set_write_timeout = mocker.MagicMock() await sockets.server.send_message( b"\x70", { @@ -263,7 +264,8 @@ async def test_hint_recv_timeout_seconds_gets_ignored( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) await connection.hello() - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() + sockets.client.set_write_timeout.assert_not_called() CREDENTIALS = "+++super-secret-sauce+++" @@ -290,7 +292,6 @@ async def test_credentials_are_not_logged( packer_cls=AsyncBolt4x0.PACKER_CLS, unpacker_cls=AsyncBolt4x0.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt4x0( address, diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 7f33fe011..45f56d16f 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -273,7 +273,8 @@ async def test_hint_recv_timeout_seconds_gets_ignored( packer_cls=AsyncBolt4x1.PACKER_CLS, unpacker_cls=AsyncBolt4x1.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.AsyncMock() + sockets.client.set_read_timeout = mocker.AsyncMock() + sockets.client.set_write_timeout = mocker.AsyncMock() await sockets.server.send_message( b"\x70", { @@ -285,7 +286,8 @@ async def test_hint_recv_timeout_seconds_gets_ignored( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) await connection.hello() - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() + sockets.client.set_write_timeout.assert_not_called() CREDENTIALS = "+++super-secret-sauce+++" @@ -312,7 +314,6 @@ async def test_credentials_are_not_logged( packer_cls=AsyncBolt4x1.PACKER_CLS, unpacker_cls=AsyncBolt4x1.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt4x1( address, diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index d243aef6d..52a17b0e4 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -273,7 +273,8 @@ async def test_hint_recv_timeout_seconds_gets_ignored( packer_cls=AsyncBolt4x2.PACKER_CLS, unpacker_cls=AsyncBolt4x2.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.AsyncMock() + sockets.client.set_read_timeout = mocker.AsyncMock() + sockets.client.set_write_timeout = mocker.AsyncMock() await sockets.server.send_message( b"\x70", { @@ -285,7 +286,8 @@ async def test_hint_recv_timeout_seconds_gets_ignored( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) await connection.hello() - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() + sockets.client.set_write_timeout.assert_not_called() CREDENTIALS = "+++super-secret-sauce+++" @@ -312,7 +314,6 @@ async def test_credentials_are_not_logged( packer_cls=AsyncBolt4x2.PACKER_CLS, unpacker_cls=AsyncBolt4x2.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt4x2( address, diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index e1a23e01e..072b4bc6e 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -287,7 +287,8 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt4x3.PACKER_CLS, unpacker_cls=AsyncBolt4x3.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.0", "hints": hints} ) @@ -296,19 +297,20 @@ async def test_hint_recv_timeout_seconds( ) with caplog.at_level(logging.INFO): await connection.hello() + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg @@ -341,7 +343,6 @@ async def test_credentials_are_not_logged( packer_cls=AsyncBolt4x3.PACKER_CLS, unpacker_cls=AsyncBolt4x3.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt4x3( address, diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index c24f06a90..fbc1cf95f 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -311,28 +311,32 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt4x4.PACKER_CLS, unpacker_cls=AsyncBolt4x4.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) connection = AsyncBolt4x4( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg @@ -365,7 +369,6 @@ async def test_credentials_are_not_logged( packer_cls=AsyncBolt4x4.PACKER_CLS, unpacker_cls=AsyncBolt4x4.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt4x4( address, diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index 8a5de7154..e7e917e52 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -311,28 +311,32 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt5x0.PACKER_CLS, unpacker_cls=AsyncBolt5x0.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) connection = AsyncBolt5x0( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg @@ -365,7 +369,6 @@ async def test_credentials_are_not_logged( packer_cls=AsyncBolt5x0.PACKER_CLS, unpacker_cls=AsyncBolt5x0.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = AsyncBolt5x0( address, diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 847ff059a..17774d40d 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -429,7 +429,8 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt5x1.PACKER_CLS, unpacker_cls=AsyncBolt5x1.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -437,21 +438,24 @@ async def test_hint_recv_timeout_seconds( connection = AsyncBolt5x1( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index c4c08edae..db140462d 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -427,7 +427,8 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt5x2.PACKER_CLS, unpacker_cls=AsyncBolt5x2.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -435,21 +436,24 @@ async def test_hint_recv_timeout_seconds( connection = AsyncBolt5x2( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/async_/io/test_class_bolt5x3.py b/tests/unit/async_/io/test_class_bolt5x3.py index e3a765632..018499c94 100644 --- a/tests/unit/async_/io/test_class_bolt5x3.py +++ b/tests/unit/async_/io/test_class_bolt5x3.py @@ -314,7 +314,8 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt5x3.PACKER_CLS, unpacker_cls=AsyncBolt5x3.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -322,21 +323,24 @@ async def test_hint_recv_timeout_seconds( connection = AsyncBolt5x3( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/async_/io/test_class_bolt5x4.py b/tests/unit/async_/io/test_class_bolt5x4.py index 48e741147..aa3047b85 100644 --- a/tests/unit/async_/io/test_class_bolt5x4.py +++ b/tests/unit/async_/io/test_class_bolt5x4.py @@ -319,7 +319,8 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt5x4.PACKER_CLS, unpacker_cls=AsyncBolt5x4.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -327,21 +328,24 @@ async def test_hint_recv_timeout_seconds( connection = AsyncBolt5x4( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/async_/io/test_class_bolt5x5.py b/tests/unit/async_/io/test_class_bolt5x5.py index 60ea25ee6..33f7a0d29 100644 --- a/tests/unit/async_/io/test_class_bolt5x5.py +++ b/tests/unit/async_/io/test_class_bolt5x5.py @@ -319,7 +319,8 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt5x5.PACKER_CLS, unpacker_cls=AsyncBolt5x5.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -327,21 +328,24 @@ async def test_hint_recv_timeout_seconds( connection = AsyncBolt5x5( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/async_/io/test_class_bolt5x6.py b/tests/unit/async_/io/test_class_bolt5x6.py index 1f11cf753..b118d638d 100644 --- a/tests/unit/async_/io/test_class_bolt5x6.py +++ b/tests/unit/async_/io/test_class_bolt5x6.py @@ -319,7 +319,8 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt5x6.PACKER_CLS, unpacker_cls=AsyncBolt5x6.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -327,21 +328,24 @@ async def test_hint_recv_timeout_seconds( connection = AsyncBolt5x6( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/async_/io/test_class_bolt5x7.py b/tests/unit/async_/io/test_class_bolt5x7.py index 09752de14..cfa790438 100644 --- a/tests/unit/async_/io/test_class_bolt5x7.py +++ b/tests/unit/async_/io/test_class_bolt5x7.py @@ -320,7 +320,8 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt5x7.PACKER_CLS, unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -328,21 +329,24 @@ async def test_hint_recv_timeout_seconds( connection = AsyncBolt5x7( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/async_/io/test_class_bolt5x8.py b/tests/unit/async_/io/test_class_bolt5x8.py index 6a105d1eb..bc9cd9244 100644 --- a/tests/unit/async_/io/test_class_bolt5x8.py +++ b/tests/unit/async_/io/test_class_bolt5x8.py @@ -320,7 +320,8 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt5x8.PACKER_CLS, unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -328,21 +329,24 @@ async def test_hint_recv_timeout_seconds( connection = AsyncBolt5x8( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/async_/io/test_class_bolt6x0.py b/tests/unit/async_/io/test_class_bolt6x0.py index ad6777cc0..ef7e2328f 100644 --- a/tests/unit/async_/io/test_class_bolt6x0.py +++ b/tests/unit/async_/io/test_class_bolt6x0.py @@ -320,7 +320,8 @@ async def test_hint_recv_timeout_seconds( packer_cls=AsyncBolt6x0.PACKER_CLS, unpacker_cls=AsyncBolt6x0.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -328,21 +329,24 @@ async def test_hint_recv_timeout_seconds( connection = AsyncBolt6x0( address, sockets.client, AsyncPoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): await connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/async_/test_conf.py b/tests/unit/async_/test_conf.py index 0dafe026f..d80473ca9 100644 --- a/tests/unit/async_/test_conf.py +++ b/tests/unit/async_/test_conf.py @@ -45,6 +45,7 @@ test_pool_config = { "connection_timeout": 30.0, + "connection_write_timeout": 30.0, "keep_alive": True, "max_connection_lifetime": 3600, "liveness_check_timeout": None, diff --git a/tests/unit/fixtures/socket.py b/tests/unit/fixtures/socket.py index 91c8221b5..5c5bf8ff8 100644 --- a/tests/unit/fixtures/socket.py +++ b/tests/unit/fixtures/socket.py @@ -122,7 +122,6 @@ def send_all(b): socket_mock.sendall.side_effect = send_all socket_mock.getsockname.return_value = ("localhost", 0x1234) socket_mock.getpeername.return_value = "peer_name" - socket_mock.gettimeout.return_value = None return BoltSocket(socket_mock) diff --git a/tests/unit/mixed/async_compat/test_network.py b/tests/unit/mixed/async_compat/test_network.py index 670c13ae2..75b9208db 100644 --- a/tests/unit/mixed/async_compat/test_network.py +++ b/tests/unit/mixed/async_compat/test_network.py @@ -25,6 +25,7 @@ from neo4j import _typing as t from neo4j._async.io._bolt_socket import AsyncBoltSocket from neo4j._exceptions import SocketDeadlineExceededError +from neo4j._sync.io._bolt_socket import BoltSocket from ...._async_compat.mark_decorator import mark_async_test @@ -100,10 +101,17 @@ def writer(s: AsyncBoltSocket): (5, 4, 0, 6, SocketDeadlineExceededError), ), ) -@pytest.mark.parametrize("method", ("recv", "recv_into", "sendall")) +@pytest.mark.parametrize( + ("method", "op"), + ( + ("recv", "read"), + ("recv_into", "read"), + ("sendall", "write"), + ), +) @mark_async_test -async def test_async_bolt_socket_read_timeout( - socket_factory, timeout, deadline, pre_tick, tick, exception, method +async def test_async_bolt_socket_timeout( + socket_factory, timeout, deadline, pre_tick, tick, exception, method, op ): def make_read_side_effect(freeze_time: TFreezeTime): async def read_side_effect(n): @@ -139,9 +147,9 @@ async def call_method(s: AsyncBoltSocket): with freezegun.freeze_time("1970-01-01T00:00:00") as frozen_time: socket = socket_factory() if timeout is not None: - socket.settimeout(timeout) + getattr(socket, f"set_{op}_timeout")(timeout) if deadline is not None: - socket.set_deadline(deadline) + getattr(socket, f"set_{op}_deadline")(deadline) if pre_tick: frozen_time.tick(pre_tick) @@ -161,3 +169,57 @@ async def call_method(s: AsyncBoltSocket): await call_method(socket) else: await call_method(socket) + + +@pytest.mark.parametrize( + ("timeout", "deadline", "expected_timeout"), + ( + (None, None, None), + (5, None, 5), + (1.23, None, 1.23), + (None, 5, 5), + (None, 1.23, 1.23), + (1, 2, 1), + (2, 1, 1), + (1.2, 2, 1.2), + (2, 1.2, 1.2), + (1, 2.3, 1), + (2.3, 1, 1), + ), +) +@pytest.mark.parametrize( + ("method", "op"), + ( + ("recv", "read"), + ("recv_into", "read"), + ("sendall", "write"), + ), +) +def test_bolt_socket_timeout_forwarding( + timeout, deadline, expected_timeout, method, op, mocker +): + def call_method(s: BoltSocket): + if method == "recv": + s.recv(1) + elif method == "recv_into": + b = bytearray(1) + s.recv_into(b, 1) + elif method == "sendall": + s.sendall(b"y") + else: + raise NotImplementedError(f"method: {method}") + + socket_mock = mocker.Mock(spec=socket.socket) + bolt_socket = BoltSocket(socket_mock) + + with freezegun.freeze_time("1970-01-01T00:00:00"): + if timeout is not None: + getattr(bolt_socket, f"set_{op}_timeout")(timeout) + if deadline is not None: + getattr(bolt_socket, f"set_{op}_deadline")(deadline) + + socket_mock.settimeout.assert_not_called() + + call_method(bolt_socket) + + socket_mock.settimeout.assert_called_once_with(expected_timeout) diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index fc60a4a28..d03dde69d 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -67,17 +67,23 @@ def close_side_effect(): mock.MagicMock(side_effect=close_side_effect), "close" ) - self.socket.attach_mock( - mock.Mock(return_value=None), "get_deadline" - ) + for op in ("read", "write"): + self.socket.attach_mock( + mock.Mock(return_value=None), f"get_{op}_deadline" + ) - def set_deadline_side_effect(deadline): - deadline = Deadline.from_timeout_or_deadline(deadline) - self.socket.get_deadline.return_value = deadline + def make_set_deadline_side_effect(op_): + def side_effect(deadline): + deadline = Deadline.from_timeout_or_deadline(deadline) + get_mock = getattr(self.socket, f"get_{op_}_deadline") + get_mock.return_value = deadline - self.socket.attach_mock( - mock.Mock(side_effect=set_deadline_side_effect), "set_deadline" - ) + return side_effect + + self.socket.attach_mock( + mock.Mock(side_effect=make_set_deadline_side_effect(op)), + f"set_{op}_deadline", + ) @property def is_reset(self): diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index af3d3c6b7..47a33609d 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -146,7 +146,8 @@ def test_hint_recv_timeout_seconds_gets_ignored( sockets = fake_socket_pair( address, Bolt3.PACKER_CLS, Bolt3.UNPACKER_CLS ) - sockets.client.settimeout = mocker.MagicMock() + sockets.client.set_read_timeout = mocker.MagicMock() + sockets.client.set_write_timeout = mocker.MagicMock() sockets.server.send_message( b"\x70", { @@ -158,7 +159,8 @@ def test_hint_recv_timeout_seconds_gets_ignored( address, sockets.client, PoolConfig.max_connection_lifetime ) connection.hello() - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() + sockets.client.set_write_timeout.assert_not_called() CREDENTIALS = "+++super-secret-sauce+++" @@ -185,7 +187,6 @@ def test_credentials_are_not_logged( packer_cls=Bolt3.PACKER_CLS, unpacker_cls=Bolt3.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = Bolt3( address, diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index f9dfef4b2..95f132eb7 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -251,7 +251,8 @@ def test_hint_recv_timeout_seconds_gets_ignored( packer_cls=Bolt4x0.PACKER_CLS, unpacker_cls=Bolt4x0.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.MagicMock() + sockets.client.set_read_timeout = mocker.MagicMock() + sockets.client.set_write_timeout = mocker.MagicMock() sockets.server.send_message( b"\x70", { @@ -263,7 +264,8 @@ def test_hint_recv_timeout_seconds_gets_ignored( address, sockets.client, PoolConfig.max_connection_lifetime ) connection.hello() - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() + sockets.client.set_write_timeout.assert_not_called() CREDENTIALS = "+++super-secret-sauce+++" @@ -290,7 +292,6 @@ def test_credentials_are_not_logged( packer_cls=Bolt4x0.PACKER_CLS, unpacker_cls=Bolt4x0.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = Bolt4x0( address, diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index 219a9fdaf..7a97daa9b 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -273,7 +273,8 @@ def test_hint_recv_timeout_seconds_gets_ignored( packer_cls=Bolt4x1.PACKER_CLS, unpacker_cls=Bolt4x1.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.MagicMock() + sockets.client.set_read_timeout = mocker.MagicMock() + sockets.client.set_write_timeout = mocker.MagicMock() sockets.server.send_message( b"\x70", { @@ -285,7 +286,8 @@ def test_hint_recv_timeout_seconds_gets_ignored( address, sockets.client, PoolConfig.max_connection_lifetime ) connection.hello() - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() + sockets.client.set_write_timeout.assert_not_called() CREDENTIALS = "+++super-secret-sauce+++" @@ -312,7 +314,6 @@ def test_credentials_are_not_logged( packer_cls=Bolt4x1.PACKER_CLS, unpacker_cls=Bolt4x1.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = Bolt4x1( address, diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index 944f7c280..f9ee311d3 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -273,7 +273,8 @@ def test_hint_recv_timeout_seconds_gets_ignored( packer_cls=Bolt4x2.PACKER_CLS, unpacker_cls=Bolt4x2.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.MagicMock() + sockets.client.set_read_timeout = mocker.MagicMock() + sockets.client.set_write_timeout = mocker.MagicMock() sockets.server.send_message( b"\x70", { @@ -285,7 +286,8 @@ def test_hint_recv_timeout_seconds_gets_ignored( address, sockets.client, PoolConfig.max_connection_lifetime ) connection.hello() - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() + sockets.client.set_write_timeout.assert_not_called() CREDENTIALS = "+++super-secret-sauce+++" @@ -312,7 +314,6 @@ def test_credentials_are_not_logged( packer_cls=Bolt4x2.PACKER_CLS, unpacker_cls=Bolt4x2.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = Bolt4x2( address, diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 2e53fc42d..a4572bf5d 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -287,7 +287,8 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt4x3.PACKER_CLS, unpacker_cls=Bolt4x3.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.0", "hints": hints} ) @@ -296,19 +297,20 @@ def test_hint_recv_timeout_seconds( ) with caplog.at_level(logging.INFO): connection.hello() + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg @@ -341,7 +343,6 @@ def test_credentials_are_not_logged( packer_cls=Bolt4x3.PACKER_CLS, unpacker_cls=Bolt4x3.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = Bolt4x3( address, diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 9378e8bf1..f772ab247 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -311,28 +311,32 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt4x4.PACKER_CLS, unpacker_cls=Bolt4x4.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) connection = Bolt4x4( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg @@ -365,7 +369,6 @@ def test_credentials_are_not_logged( packer_cls=Bolt4x4.PACKER_CLS, unpacker_cls=Bolt4x4.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = Bolt4x4( address, diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index 2390112d6..4a59888dc 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -311,28 +311,32 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt5x0.PACKER_CLS, unpacker_cls=Bolt5x0.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) connection = Bolt5x0( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg @@ -365,7 +369,6 @@ def test_credentials_are_not_logged( packer_cls=Bolt5x0.PACKER_CLS, unpacker_cls=Bolt5x0.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) connection = Bolt5x0( address, diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index 7b2804c4a..befbf0314 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -429,7 +429,8 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt5x1.PACKER_CLS, unpacker_cls=Bolt5x1.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -437,21 +438,24 @@ def test_hint_recv_timeout_seconds( connection = Bolt5x1( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py index 165d17763..e4557b41a 100644 --- a/tests/unit/sync/io/test_class_bolt5x2.py +++ b/tests/unit/sync/io/test_class_bolt5x2.py @@ -427,7 +427,8 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt5x2.PACKER_CLS, unpacker_cls=Bolt5x2.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -435,21 +436,24 @@ def test_hint_recv_timeout_seconds( connection = Bolt5x2( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/sync/io/test_class_bolt5x3.py b/tests/unit/sync/io/test_class_bolt5x3.py index d0d171314..ca40dfb7a 100644 --- a/tests/unit/sync/io/test_class_bolt5x3.py +++ b/tests/unit/sync/io/test_class_bolt5x3.py @@ -314,7 +314,8 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt5x3.PACKER_CLS, unpacker_cls=Bolt5x3.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -322,21 +323,24 @@ def test_hint_recv_timeout_seconds( connection = Bolt5x3( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/sync/io/test_class_bolt5x4.py b/tests/unit/sync/io/test_class_bolt5x4.py index ea938cc2a..58517018d 100644 --- a/tests/unit/sync/io/test_class_bolt5x4.py +++ b/tests/unit/sync/io/test_class_bolt5x4.py @@ -319,7 +319,8 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt5x4.PACKER_CLS, unpacker_cls=Bolt5x4.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -327,21 +328,24 @@ def test_hint_recv_timeout_seconds( connection = Bolt5x4( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/sync/io/test_class_bolt5x5.py b/tests/unit/sync/io/test_class_bolt5x5.py index e5cc6e744..56a139173 100644 --- a/tests/unit/sync/io/test_class_bolt5x5.py +++ b/tests/unit/sync/io/test_class_bolt5x5.py @@ -319,7 +319,8 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt5x5.PACKER_CLS, unpacker_cls=Bolt5x5.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -327,21 +328,24 @@ def test_hint_recv_timeout_seconds( connection = Bolt5x5( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/sync/io/test_class_bolt5x6.py b/tests/unit/sync/io/test_class_bolt5x6.py index a472ef5f2..e4f5488d9 100644 --- a/tests/unit/sync/io/test_class_bolt5x6.py +++ b/tests/unit/sync/io/test_class_bolt5x6.py @@ -319,7 +319,8 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt5x6.PACKER_CLS, unpacker_cls=Bolt5x6.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -327,21 +328,24 @@ def test_hint_recv_timeout_seconds( connection = Bolt5x6( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/sync/io/test_class_bolt5x7.py b/tests/unit/sync/io/test_class_bolt5x7.py index 95890c79a..e7f7532c2 100644 --- a/tests/unit/sync/io/test_class_bolt5x7.py +++ b/tests/unit/sync/io/test_class_bolt5x7.py @@ -320,7 +320,8 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt5x7.PACKER_CLS, unpacker_cls=Bolt5x7.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -328,21 +329,24 @@ def test_hint_recv_timeout_seconds( connection = Bolt5x7( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/sync/io/test_class_bolt5x8.py b/tests/unit/sync/io/test_class_bolt5x8.py index 25de8c2b8..79ac5ba29 100644 --- a/tests/unit/sync/io/test_class_bolt5x8.py +++ b/tests/unit/sync/io/test_class_bolt5x8.py @@ -320,7 +320,8 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt5x8.PACKER_CLS, unpacker_cls=Bolt5x8.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -328,21 +329,24 @@ def test_hint_recv_timeout_seconds( connection = Bolt5x8( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/sync/io/test_class_bolt6x0.py b/tests/unit/sync/io/test_class_bolt6x0.py index eb8b3867c..443fc463e 100644 --- a/tests/unit/sync/io/test_class_bolt6x0.py +++ b/tests/unit/sync/io/test_class_bolt6x0.py @@ -320,7 +320,8 @@ def test_hint_recv_timeout_seconds( packer_cls=Bolt6x0.PACKER_CLS, unpacker_cls=Bolt6x0.UNPACKER_CLS, ) - sockets.client.settimeout = mocker.Mock() + sockets.client.set_read_timeout = mocker.Mock() + sockets.client.set_write_timeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) @@ -328,21 +329,24 @@ def test_hint_recv_timeout_seconds( connection = Bolt6x0( address, sockets.client, PoolConfig.max_connection_lifetime ) + with caplog.at_level(logging.INFO): connection.hello() + + sockets.client.set_write_timeout.assert_not_called() if valid: if "connection.recv_timeout_seconds" in hints: - sockets.client.settimeout.assert_called_once_with( + sockets.client.set_read_timeout.assert_called_once_with( hints["connection.recv_timeout_seconds"] ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert not any( "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages ) else: - sockets.client.settimeout.assert_not_called() + sockets.client.set_read_timeout.assert_not_called() assert any( repr(hints["connection.recv_timeout_seconds"]) in msg and "recv_timeout_seconds" in msg diff --git a/tests/unit/sync/test_conf.py b/tests/unit/sync/test_conf.py index 7d8ec73a5..12d0b299a 100644 --- a/tests/unit/sync/test_conf.py +++ b/tests/unit/sync/test_conf.py @@ -45,6 +45,7 @@ test_pool_config = { "connection_timeout": 30.0, + "connection_write_timeout": 30.0, "keep_alive": True, "max_connection_lifetime": 3600, "liveness_check_timeout": None,