diff --git a/docs/source/api.rst b/docs/source/api.rst index 20219c193..a84387638 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -399,6 +399,7 @@ Additional configuration can be provided via the :class:`neo4j.Driver` construct + :ref:`user-agent-ref` + :ref:`driver-notifications-min-severity-ref` + :ref:`driver-notifications-disabled-categories-ref` ++ :ref:`telemetry-disabled-ref` .. _connection-acquisition-timeout-ref: @@ -664,6 +665,30 @@ Notifications are available via :attr:`.ResultSummary.notifications` and :attr:` .. seealso:: :class:`.NotificationDisabledCategory`, session config :ref:`session-notifications-disabled-categories-ref` +.. _telemetry-disabled-ref: + +``telemetry_disabled`` +---------------------- +By default, the driver will send anonymous usage statistics to the server it connects to if the server requests those. +By setting ``telemetry_disabled=True``, the driver will not send any telemetry data. + +The driver transmits the following information: + +* Every time one of the following APIs is used to execute a query (for the first time), the server is informed of this + (without any further information like arguments, client identifiers, etc.): + + * :meth:`.Driver.execute_query` + * :meth:`.Session.begin_transaction` + * :meth:`.Session.execute_read`, :meth:`.Session.execute_write` + * :meth:`.Session.run` + * the async counterparts of the above methods + +:Type: :class:`bool` +:Default: :data:`False` + +.. versionadded:: 5.13 + + Driver Object Lifetime ====================== diff --git a/src/neo4j/_api.py b/src/neo4j/_api.py index 27945c98d..a53cdff43 100644 --- a/src/neo4j/_api.py +++ b/src/neo4j/_api.py @@ -32,6 +32,7 @@ "NotificationCategory", "NotificationSeverity", "RoutingControl", + "TelemetryAPI" ] @@ -227,6 +228,13 @@ class RoutingControl(str, Enum): WRITE = "w" +class TelemetryAPI(int, Enum): + TX_FUNC = 0 + TX = 1 + AUTO_COMMIT = 2 + DRIVER = 3 + + if t.TYPE_CHECKING: T_RoutingControl = t.Union[ RoutingControl, diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index cc923b553..de5e8bb28 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -31,7 +31,10 @@ T_NotificationMinimumSeverity, ) -from .._api import RoutingControl +from .._api import ( + RoutingControl, + TelemetryAPI, +) from .._async_compat.util import AsyncUtil from .._conf import ( Config, @@ -71,6 +74,7 @@ URI_SCHEME_NEO4J, URI_SCHEME_NEO4J_SECURE, URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + WRITE_ACCESS, ) from ..auth_management import ( AsyncAuthManager, @@ -159,7 +163,8 @@ def driver( fetch_size: int = ..., impersonated_user: t.Optional[str] = ..., bookmark_manager: t.Union[AsyncBookmarkManager, - BookmarkManager, None] = ... + BookmarkManager, None] = ..., + telemetry_disabled: bool = ..., ) -> AsyncDriver: ... @@ -866,15 +871,16 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: session = self._session(session_config) async with session: if routing_ == RoutingControl.WRITE: - executor = session.execute_write + access_mode = WRITE_ACCESS elif routing_ == RoutingControl.READ: - executor = session.execute_read + access_mode = READ_ACCESS else: raise ValueError("Invalid routing control value: %r" % routing_) with session._pipelined_begin: - return await executor( - _work, query_, parameters, result_transformer_ + return await session._run_transaction( + access_mode, TelemetryAPI.DRIVER, + _work, (query_, parameters, result_transformer_), {} ) @property diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index f39fa1c83..afff0f257 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -25,6 +25,7 @@ from logging import getLogger from time import perf_counter +from ..._api import TelemetryAPI from ..._async_compat.network import AsyncBoltSocket from ..._async_compat.util import AsyncUtil from ..._codec.hydration import v1 as hydration_v1 @@ -134,7 +135,8 @@ class AsyncBolt: def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, auth_manager=None, user_agent=None, routing_context=None, notifications_min_severity=None, - notifications_disabled_categories=None): + notifications_disabled_categories=None, + telemetry_disabled=False): self.unresolved_address = unresolved_address self.socket = sock self.local_port = self.socket.getsockname()[1] @@ -172,6 +174,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, self.auth = auth self.auth_dict = self._to_auth_dict(auth) self.auth_manager = auth_manager + self.telemetry_disabled = telemetry_disabled self.notifications_min_severity = notifications_min_severity self.notifications_disabled_categories = \ @@ -280,6 +283,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x1, AsyncBolt5x2, AsyncBolt5x3, + AsyncBolt5x4, ) handlers = { @@ -293,6 +297,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x1.PROTOCOL_VERSION: AsyncBolt5x1, AsyncBolt5x2.PROTOCOL_VERSION: AsyncBolt5x2, AsyncBolt5x3.PROTOCOL_VERSION: AsyncBolt5x3, + AsyncBolt5x4.PROTOCOL_VERSION: AsyncBolt5x4, } if protocol_version is None: @@ -407,7 +412,10 @@ async def open( # Carry out Bolt subclass imports locally to avoid circular dependency # issues. - if protocol_version == (5, 3): + if protocol_version == (5, 4): + from ._bolt5 import AsyncBolt5x4 + bolt_cls = AsyncBolt5x4 + elif protocol_version == (5, 3): from ._bolt5 import AsyncBolt5x3 bolt_cls = AsyncBolt5x3 elif protocol_version == (5, 2): @@ -471,7 +479,8 @@ async def open( routing_context=routing_context, notifications_min_severity=pool_config.notifications_min_severity, notifications_disabled_categories= - pool_config.notifications_disabled_categories + pool_config.notifications_disabled_categories, + telemetry_disabled=pool_config.telemetry_disabled, ) try: @@ -555,7 +564,6 @@ def re_auth( hydration_hooks=hydration_hooks) return True - @abc.abstractmethod async def route( self, database=None, imp_user=None, bookmarks=None, @@ -584,6 +592,23 @@ async def route( """ pass + @abc.abstractmethod + def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, + hydration_hooks=None, **handlers) -> None: + """Send telemetry information about the API usage to the server. + + :param api: the API used. + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. + """ + pass + @abc.abstractmethod def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, imp_user=None, diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index d272c30d2..9e4bac484 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -23,6 +23,7 @@ from logging import getLogger from ssl import SSLSocket +from ..._api import TelemetryAPI from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, @@ -225,6 +226,11 @@ def logoff(self, dehydration_hooks=None, hydration_hooks=None): """Append a LOGOFF message to the outgoing queue.""" self.assert_re_auth_support() + def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, + hydration_hooks=None, **handlers) -> None: + # TELEMETRY not support by this protocol version, so we ignore it. + pass + async def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 149fd786b..e284ba6f3 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -19,6 +19,7 @@ from logging import getLogger from ssl import SSLSocket +from ..._api import TelemetryAPI from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, @@ -146,6 +147,11 @@ def logoff(self, dehydration_hooks=None, hydration_hooks=None): """Append a LOGOFF message to the outgoing queue.""" self.assert_re_auth_support() + def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, + hydration_hooks=None, **handlers) -> None: + # TELEMETRY not support by this protocol version, so we ignore it. + pass + async def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index d9b6071ec..c5a2cafdb 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -21,6 +21,7 @@ from logging import getLogger from ssl import SSLSocket +from ..._api import TelemetryAPI from ..._codec.hydration import v2 as hydration_v2 from ..._exceptions import BoltProtocolError from ..._meta import BOLT_AGENT_DICT @@ -168,6 +169,11 @@ def logoff(self, dehydration_hooks=None, hydration_hooks=None): """Append a LOGOFF message to the outgoing queue.""" self.assert_re_auth_support() + def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, + hydration_hooks=None, **handlers) -> None: + # TELEMETRY not support by this protocol version, so we ignore it. + pass + async def route(self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None): routing_context = self.routing_context or {} @@ -665,3 +671,22 @@ def get_base_headers(self): headers = super().get_base_headers() headers["bolt_agent"] = BOLT_AGENT_DICT return headers + + +class AsyncBolt5x4(AsyncBolt5x3): + + PROTOCOL_VERSION = Version(5, 4) + + def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, + hydration_hooks=None, **handlers) -> None: + if ( + self.telemetry_disabled + or not self.configuration_hints.get("telemetry.enabled", False) + ): + return + api_raw = int(api) + log.debug("[#%04X] C: TELEMETRY %i # (%r)", + self.local_port, api_raw, api) + self._append(b"\x54", (api_raw,), + Response(self, "telemetry", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index f01cce16f..4cada63b5 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -24,6 +24,7 @@ from random import random from time import perf_counter +from ..._api import TelemetryAPI from ..._async_compat import async_sleep from ..._async_compat.util import AsyncUtil from ..._conf import SessionConfig @@ -300,8 +301,10 @@ async def run( if not self._connection: await self._connect(self._config.default_access_mode) + assert self._connection is not None cx = self._connection + cx.telemetry(TelemetryAPI.AUTO_COMMIT) self._auto_result = AsyncResult( cx, self._config.fetch_size, self._result_closed, self._result_error @@ -406,9 +409,20 @@ def _transaction_cancel_handler(self): ) async def _open_transaction( - self, *, tx_cls, access_mode, metadata=None, timeout=None - ): + self, + *, + tx_cls: t.Callable[ + ..., t.Union[AsyncTransaction, AsyncManagedTransaction] + ], + access_mode, api: t.Optional[TelemetryAPI], + metadata=None, + timeout=None, + api_success_cb: t.Optional[t.Callable[[dict], None]] = None, + ) -> None: await self._connect(access_mode=access_mode) + assert self._connection is not None + if api is not None: + self._connection.telemetry(api, on_success=api_success_cb) self._transaction = tx_cls( self._connection, self._config.fetch_size, self._transaction_closed_handler, @@ -474,21 +488,21 @@ async def begin_transaction( ) await self._open_transaction( - tx_cls=AsyncTransaction, + tx_cls=AsyncTransaction, api=TelemetryAPI.TX, access_mode=self._config.default_access_mode, metadata=metadata, timeout=timeout ) return t.cast(AsyncTransaction, self._transaction) - async def _run_transaction( self, access_mode: str, + api: TelemetryAPI, transaction_function: t.Callable[ te.Concatenate[AsyncManagedTransaction, _P], t.Awaitable[_R] ], - *args: _P.args, **kwargs: _P.kwargs + args: _P.args, kwargs: _P.kwargs ) -> _R: self._check_state() if not callable(transaction_function): @@ -503,6 +517,12 @@ async def _run_transaction( self._config.retry_delay_jitter_factor ) + telemetry_sent = False + + def api_success_cb(meta): + nonlocal telemetry_sent + telemetry_sent = True + errors = [] t0: float = -1 # Timer @@ -511,8 +531,9 @@ async def _run_transaction( try: await self._open_transaction( tx_cls=AsyncManagedTransaction, + api=None if telemetry_sent else api, access_mode=access_mode, metadata=metadata, - timeout=timeout + timeout=timeout, api_success_cb=api_success_cb, ) assert isinstance(self._transaction, AsyncManagedTransaction) tx = self._transaction @@ -626,7 +647,8 @@ async def get_two_tx(tx): .. versionadded:: 5.0 """ return await self._run_transaction( - READ_ACCESS, transaction_function, *args, **kwargs + READ_ACCESS, TelemetryAPI.TX_FUNC, + transaction_function, args, kwargs ) # TODO: 6.0 - Remove this method @@ -664,7 +686,8 @@ async def read_transaction( Method was renamed to :meth:`.execute_read`. """ return await self._run_transaction( - READ_ACCESS, transaction_function, *args, **kwargs + READ_ACCESS, TelemetryAPI.TX_FUNC, + transaction_function, args, kwargs ) async def execute_write( @@ -718,7 +741,8 @@ async def create_node_tx(tx, name): .. versionadded:: 5.0 """ return await self._run_transaction( - WRITE_ACCESS, transaction_function, *args, **kwargs + WRITE_ACCESS, TelemetryAPI.TX_FUNC, + transaction_function, args, kwargs ) # TODO: 6.0 - Remove this method @@ -756,7 +780,8 @@ async def write_transaction( Method was renamed to :meth:`.execute_write`. """ return await self._run_transaction( - WRITE_ACCESS, transaction_function, *args, **kwargs + WRITE_ACCESS, TelemetryAPI.TX_FUNC, + transaction_function, args, kwargs ) diff --git a/src/neo4j/_conf.py b/src/neo4j/_conf.py index 9d743f579..e17b9ef36 100644 --- a/src/neo4j/_conf.py +++ b/src/neo4j/_conf.py @@ -411,6 +411,9 @@ class PoolConfig(Config): #: List of notification categories for the server to ignore notifications_disabled_categories = None + #: Opt-Out of telemetry collection + telemetry_disabled = False + def get_ssl_context(self): if self.ssl_context is not None: return self.ssl_context diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 754c2138c..7e18adf2a 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -31,7 +31,10 @@ T_NotificationMinimumSeverity, ) -from .._api import RoutingControl +from .._api import ( + RoutingControl, + TelemetryAPI, +) from .._async_compat.util import Util from .._conf import ( Config, @@ -70,6 +73,7 @@ URI_SCHEME_NEO4J, URI_SCHEME_NEO4J_SECURE, URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + WRITE_ACCESS, ) from ..auth_management import ( AuthManager, @@ -158,7 +162,8 @@ def driver( fetch_size: int = ..., impersonated_user: t.Optional[str] = ..., bookmark_manager: t.Union[BookmarkManager, - BookmarkManager, None] = ... + BookmarkManager, None] = ..., + telemetry_disabled: bool = ..., ) -> Driver: ... @@ -865,15 +870,16 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: session = self._session(session_config) with session: if routing_ == RoutingControl.WRITE: - executor = session.execute_write + access_mode = WRITE_ACCESS elif routing_ == RoutingControl.READ: - executor = session.execute_read + access_mode = READ_ACCESS else: raise ValueError("Invalid routing control value: %r" % routing_) with session._pipelined_begin: - return executor( - _work, query_, parameters, result_transformer_ + return session._run_transaction( + access_mode, TelemetryAPI.DRIVER, + _work, (query_, parameters, result_transformer_), {} ) @property diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 24336fa8e..22c1efcaa 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -25,6 +25,7 @@ from logging import getLogger from time import perf_counter +from ..._api import TelemetryAPI from ..._async_compat.network import BoltSocket from ..._async_compat.util import Util from ..._codec.hydration import v1 as hydration_v1 @@ -134,7 +135,8 @@ class Bolt: def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, auth_manager=None, user_agent=None, routing_context=None, notifications_min_severity=None, - notifications_disabled_categories=None): + notifications_disabled_categories=None, + telemetry_disabled=False): self.unresolved_address = unresolved_address self.socket = sock self.local_port = self.socket.getsockname()[1] @@ -172,6 +174,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, self.auth = auth self.auth_dict = self._to_auth_dict(auth) self.auth_manager = auth_manager + self.telemetry_disabled = telemetry_disabled self.notifications_min_severity = notifications_min_severity self.notifications_disabled_categories = \ @@ -280,6 +283,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x1, Bolt5x2, Bolt5x3, + Bolt5x4, ) handlers = { @@ -293,6 +297,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x1.PROTOCOL_VERSION: Bolt5x1, Bolt5x2.PROTOCOL_VERSION: Bolt5x2, Bolt5x3.PROTOCOL_VERSION: Bolt5x3, + Bolt5x4.PROTOCOL_VERSION: Bolt5x4, } if protocol_version is None: @@ -407,7 +412,10 @@ def open( # Carry out Bolt subclass imports locally to avoid circular dependency # issues. - if protocol_version == (5, 3): + if protocol_version == (5, 4): + from ._bolt5 import Bolt5x4 + bolt_cls = Bolt5x4 + elif protocol_version == (5, 3): from ._bolt5 import Bolt5x3 bolt_cls = Bolt5x3 elif protocol_version == (5, 2): @@ -471,7 +479,8 @@ def open( routing_context=routing_context, notifications_min_severity=pool_config.notifications_min_severity, notifications_disabled_categories= - pool_config.notifications_disabled_categories + pool_config.notifications_disabled_categories, + telemetry_disabled=pool_config.telemetry_disabled, ) try: @@ -555,7 +564,6 @@ def re_auth( hydration_hooks=hydration_hooks) return True - @abc.abstractmethod def route( self, database=None, imp_user=None, bookmarks=None, @@ -584,6 +592,23 @@ def route( """ pass + @abc.abstractmethod + def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, + hydration_hooks=None, **handlers) -> None: + """Send telemetry information about the API usage to the server. + + :param api: the API used. + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. + """ + pass + @abc.abstractmethod def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, imp_user=None, diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index eff579493..ca8d4d36a 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -23,6 +23,7 @@ from logging import getLogger from ssl import SSLSocket +from ..._api import TelemetryAPI from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, @@ -225,6 +226,11 @@ def logoff(self, dehydration_hooks=None, hydration_hooks=None): """Append a LOGOFF message to the outgoing queue.""" self.assert_re_auth_support() + def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, + hydration_hooks=None, **handlers) -> None: + # TELEMETRY not support by this protocol version, so we ignore it. + pass + def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 44f012f13..58bdd8b1e 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -19,6 +19,7 @@ from logging import getLogger from ssl import SSLSocket +from ..._api import TelemetryAPI from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, @@ -146,6 +147,11 @@ def logoff(self, dehydration_hooks=None, hydration_hooks=None): """Append a LOGOFF message to the outgoing queue.""" self.assert_re_auth_support() + def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, + hydration_hooks=None, **handlers) -> None: + # TELEMETRY not support by this protocol version, so we ignore it. + pass + def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 73882c17d..d1668784c 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -21,6 +21,7 @@ from logging import getLogger from ssl import SSLSocket +from ..._api import TelemetryAPI from ..._codec.hydration import v2 as hydration_v2 from ..._exceptions import BoltProtocolError from ..._meta import BOLT_AGENT_DICT @@ -168,6 +169,11 @@ def logoff(self, dehydration_hooks=None, hydration_hooks=None): """Append a LOGOFF message to the outgoing queue.""" self.assert_re_auth_support() + def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, + hydration_hooks=None, **handlers) -> None: + # TELEMETRY not support by this protocol version, so we ignore it. + pass + def route(self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None): routing_context = self.routing_context or {} @@ -665,3 +671,22 @@ def get_base_headers(self): headers = super().get_base_headers() headers["bolt_agent"] = BOLT_AGENT_DICT return headers + + +class Bolt5x4(Bolt5x3): + + PROTOCOL_VERSION = Version(5, 4) + + def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, + hydration_hooks=None, **handlers) -> None: + if ( + self.telemetry_disabled + or not self.configuration_hints.get("telemetry.enabled", False) + ): + return + api_raw = int(api) + log.debug("[#%04X] C: TELEMETRY %i # (%r)", + self.local_port, api_raw, api) + self._append(b"\x54", (api_raw,), + Response(self, "telemetry", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index a60934b7c..7ad782194 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -24,6 +24,7 @@ from random import random from time import perf_counter +from ..._api import TelemetryAPI from ..._async_compat import sleep from ..._async_compat.util import Util from ..._conf import SessionConfig @@ -300,8 +301,10 @@ def run( if not self._connection: self._connect(self._config.default_access_mode) + assert self._connection is not None cx = self._connection + cx.telemetry(TelemetryAPI.AUTO_COMMIT) self._auto_result = Result( cx, self._config.fetch_size, self._result_closed, self._result_error @@ -406,9 +409,20 @@ def _transaction_cancel_handler(self): ) def _open_transaction( - self, *, tx_cls, access_mode, metadata=None, timeout=None - ): + self, + *, + tx_cls: t.Callable[ + ..., t.Union[Transaction, ManagedTransaction] + ], + access_mode, api: t.Optional[TelemetryAPI], + metadata=None, + timeout=None, + api_success_cb: t.Optional[t.Callable[[dict], None]] = None, + ) -> None: self._connect(access_mode=access_mode) + assert self._connection is not None + if api is not None: + self._connection.telemetry(api, on_success=api_success_cb) self._transaction = tx_cls( self._connection, self._config.fetch_size, self._transaction_closed_handler, @@ -474,21 +488,21 @@ def begin_transaction( ) self._open_transaction( - tx_cls=Transaction, + tx_cls=Transaction, api=TelemetryAPI.TX, access_mode=self._config.default_access_mode, metadata=metadata, timeout=timeout ) return t.cast(Transaction, self._transaction) - def _run_transaction( self, access_mode: str, + api: TelemetryAPI, transaction_function: t.Callable[ te.Concatenate[ManagedTransaction, _P], t.Union[_R] ], - *args: _P.args, **kwargs: _P.kwargs + args: _P.args, kwargs: _P.kwargs ) -> _R: self._check_state() if not callable(transaction_function): @@ -503,6 +517,12 @@ def _run_transaction( self._config.retry_delay_jitter_factor ) + telemetry_sent = False + + def api_success_cb(meta): + nonlocal telemetry_sent + telemetry_sent = True + errors = [] t0: float = -1 # Timer @@ -511,8 +531,9 @@ def _run_transaction( try: self._open_transaction( tx_cls=ManagedTransaction, + api=None if telemetry_sent else api, access_mode=access_mode, metadata=metadata, - timeout=timeout + timeout=timeout, api_success_cb=api_success_cb, ) assert isinstance(self._transaction, ManagedTransaction) tx = self._transaction @@ -626,7 +647,8 @@ def get_two_tx(tx): .. versionadded:: 5.0 """ return self._run_transaction( - READ_ACCESS, transaction_function, *args, **kwargs + READ_ACCESS, TelemetryAPI.TX_FUNC, + transaction_function, args, kwargs ) # TODO: 6.0 - Remove this method @@ -664,7 +686,8 @@ def read_transaction( Method was renamed to :meth:`.execute_read`. """ return self._run_transaction( - READ_ACCESS, transaction_function, *args, **kwargs + READ_ACCESS, TelemetryAPI.TX_FUNC, + transaction_function, args, kwargs ) def execute_write( @@ -718,7 +741,8 @@ def create_node_tx(tx, name): .. versionadded:: 5.0 """ return self._run_transaction( - WRITE_ACCESS, transaction_function, *args, **kwargs + WRITE_ACCESS, TelemetryAPI.TX_FUNC, + transaction_function, args, kwargs ) # TODO: 6.0 - Remove this method @@ -756,7 +780,8 @@ def write_transaction( Method was renamed to :meth:`.execute_write`. """ return self._run_transaction( - WRITE_ACCESS, transaction_function, *args, **kwargs + WRITE_ACCESS, TelemetryAPI.TX_FUNC, + transaction_function, args, kwargs ) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index f272986f6..1ad3cbdb2 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -134,6 +134,7 @@ async def NewDriver(backend, data): for (conf_name, data_name) in ( ("max_connection_pool_size", "maxConnectionPoolSize"), ("fetch_size", "fetchSize"), + ("telemetry_disabled", "telemetryDisabled") ): if data.get(data_name): kwargs[conf_name] = data[data_name] diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 7dd69cf7b..9932021e3 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -134,6 +134,7 @@ def NewDriver(backend, data): for (conf_name, data_name) in ( ("max_connection_pool_size", "maxConnectionPoolSize"), ("fetch_size", "fetchSize"), + ("telemetry_disabled", "telemetryDisabled") ): if data.get(data_name): kwargs[conf_name] = data[data_name] diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 68f776bbd..5957054e9 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -56,6 +56,7 @@ "Feature:Bolt:5.1": true, "Feature:Bolt:5.2": true, "Feature:Bolt:5.3": true, + "Feature:Bolt:5.4": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index eea7bde04..1a55b10b7 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -39,7 +39,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), } protocol_handlers = AsyncBolt.protocol_handlers() @@ -65,7 +65,8 @@ def test_class_method_protocol_handlers(): ((5, 1), 1), ((5, 2), 1), ((5, 3), 1), - ((5, 4), 0), + ((5, 4), 1), + ((5, 5), 0), ((6, 0), 0), ] ) @@ -85,7 +86,7 @@ def test_class_method_protocol_handlers_with_invalid_protocol_version(): # [bolt-version-bump] search tag when changing bolt version support def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() - assert (b"\x00\x03\x03\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + assert (b"\x00\x04\x04\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" == handshake) @@ -132,6 +133,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ((5, 1), "neo4j._async.io._bolt5.AsyncBolt5x1"), ((5, 2), "neo4j._async.io._bolt5.AsyncBolt5x2"), ((5, 3), "neo4j._async.io._bolt5.AsyncBolt5x3"), + ((5, 4), "neo4j._async.io._bolt5.AsyncBolt5x4"), ), ) @mark_async_test @@ -164,13 +166,15 @@ async def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 4), + (5, 5), (6, 0), )) @mark_async_test async def test_failing_version_negotiation(mocker, bolt_version, none_auth): - supported_protocols = \ - "('3.0', '4.1', '4.2', '4.3', '4.4', '5.0', '5.1', '5.2', '5.3')" + supported_protocols = ( + "('3.0', '4.1', '4.2', '4.3', '4.4', " + "'5.0', '5.1', '5.2', '5.3', '5.4')" + ) address = ("localhost", 7687) socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) @@ -188,7 +192,6 @@ async def test_failing_version_negotiation(mocker, bolt_version, none_auth): assert exc.match(supported_protocols) - @AsyncTestDecorators.mark_async_only_test async def test_cancel_manager_in_open(mocker): address = ("localhost", 7687) diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index 635d9e1e4..62fb5fc79 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -14,6 +14,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + import contextlib import itertools import logging @@ -21,6 +23,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._async.io._bolt3 import AsyncBolt3 from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT @@ -97,6 +100,28 @@ async def test_simple_pull(fake_socket): assert len(fields) == 0 +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) + connection = AsyncBolt3( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + with pytest.raises(OSError): + await socket.pop_message() + + @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_async_test async def test_hint_recv_timeout_seconds_gets_ignored( diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index b667ba814..970b13f8b 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -16,13 +16,13 @@ # limitations under the License. -import contextlib import itertools import logging import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._async.io._bolt4 import AsyncBolt4x0 from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT @@ -195,6 +195,28 @@ async def test_n_and_qid_extras_in_pull(fake_socket): assert fields[0] == {"n": 666, "qid": 777} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) + connection = AsyncBolt4x0( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + with pytest.raises(OSError): + await socket.pop_message() + + @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_async_test async def test_hint_recv_timeout_seconds_gets_ignored( diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index a28580463..5f3cd1de0 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._async.io._bolt4 import AsyncBolt4x1 from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT @@ -212,6 +213,28 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) + connection = AsyncBolt4x1( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + with pytest.raises(OSError): + await socket.pop_message() + + @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_async_test async def test_hint_recv_timeout_seconds_gets_ignored( diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index 41d732f20..187d7e9e0 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._async.io._bolt4 import AsyncBolt4x2 from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT @@ -212,6 +213,28 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) + connection = AsyncBolt4x2( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + with pytest.raises(OSError): + await socket.pop_message() + + @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_async_test async def test_hint_recv_timeout_seconds_gets_ignored( diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index a05656d5b..a452e4cbe 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._async.io._bolt4 import AsyncBolt4x3 from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT @@ -212,6 +213,28 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) + connection = AsyncBolt4x3( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + with pytest.raises(OSError): + await socket.pop_message() + + @pytest.mark.parametrize(("hints", "valid"), ( ({"connection.recv_timeout_seconds": 1}, True), ({"connection.recv_timeout_seconds": 42}, True), diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index 9923c5047..76117f55b 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._async.io._bolt4 import AsyncBolt4x4 from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT @@ -226,6 +227,28 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) + connection = AsyncBolt4x4( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + with pytest.raises(OSError): + await socket.pop_message() + + @pytest.mark.parametrize(("hints", "valid"), ( ({"connection.recv_timeout_seconds": 1}, True), ({"connection.recv_timeout_seconds": 42}, True), diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index c6f4d3297..033676d9a 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._async.io._bolt5 import AsyncBolt5x0 from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT @@ -226,6 +227,28 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) + connection = AsyncBolt5x0( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + with pytest.raises(OSError): + await socket.pop_message() + + @pytest.mark.parametrize(("hints", "valid"), ( ({"connection.recv_timeout_seconds": 1}, True), ({"connection.recv_timeout_seconds": 42}, True), diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 2ee26b130..6bff1868d 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -23,6 +23,7 @@ import neo4j import neo4j.exceptions +from neo4j._api import TelemetryAPI from neo4j._async.io._bolt5 import AsyncBolt5x1 from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT @@ -239,6 +240,28 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + with pytest.raises(OSError): + await socket.pop_message() + + async def _assert_logon_message(sockets, auth): tag, fields = await sockets.server.pop_message() assert tag == b"\x6A" # LOGON diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index 429b02687..f22b9c173 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._async.io._bolt5 import AsyncBolt5x2 from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT @@ -226,6 +227,28 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) + connection = AsyncBolt5x2( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + with pytest.raises(OSError): + await socket.pop_message() + + async def _assert_logon_message(sockets, auth): tag, fields = await sockets.server.pop_message() assert tag == b"\x6A" # LOGON diff --git a/tests/unit/async_/io/test_class_bolt5x3.py b/tests/unit/async_/io/test_class_bolt5x3.py index a11c2ccca..462639c2e 100644 --- a/tests/unit/async_/io/test_class_bolt5x3.py +++ b/tests/unit/async_/io/test_class_bolt5x3.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._async.io._bolt5 import AsyncBolt5x3 from neo4j._conf import PoolConfig from neo4j._meta import ( @@ -229,6 +230,28 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x3.UNPACKER_CLS) + connection = AsyncBolt5x3( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + with pytest.raises(OSError): + await socket.pop_message() + + @pytest.mark.parametrize(("hints", "valid"), ( ({"connection.recv_timeout_seconds": 1}, True), ({"connection.recv_timeout_seconds": 42}, True), diff --git a/tests/unit/async_/io/test_class_bolt5x4.py b/tests/unit/async_/io/test_class_bolt5x4.py new file mode 100644 index 000000000..a06bd1f8e --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x4.py @@ -0,0 +1,589 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._async.io._bolt5 import AsyncBolt5x4 +from neo4j._conf import PoolConfig +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) + +from ...._async_compat import mark_async_test + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt5x4(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = AsyncBolt5x4(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x4(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x4.PACKER_CLS, + unpacker_cls=AsyncBolt5x4.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x4( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = await socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + await socket.pop_message() + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x4.PACKER_CLS, + unpacker_cls=AsyncBolt5x4.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x4( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x4.PACKER_CLS, + unpacker_cls=AsyncBolt5x4.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x4( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + + +@pytest.mark.parametrize(("method", "args", "extra_idx"), ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), +)) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2) +) +@pytest.mark.parametrize( + ("cls_dis_cats", "method_dis_cats"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2) +) +@mark_async_test +async def test_supports_notification_filters( + fake_socket, method, args, extra_idx, cls_min_sev, method_min_sev, + cls_dis_cats, method_dis_cats +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4( + address, socket, PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_categories=cls_dis_cats + ) + method = getattr(connection, method) + + method(*args, notifications_min_severity=method_min_sev, + notifications_disabled_categories=method_dis_cats) + await connection.send_all() + + _, fields = await socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_cats is not None: + expected["notifications_disabled_categories"] = method_dis_cats + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize("dis_cats", + (None, [], ["HINT"], ["HINT", "DEPRECATION"])) +@mark_async_test +async def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_cats +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x4.PACKER_CLS, + unpacker_cls=AsyncBolt5x4.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x4( + address, sockets.client, PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_categories=dis_cats + ) + + await connection.hello() + + tag, fields = await sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_cats is not None: + expected["notifications_disabled_categories"] = dis_cats + _assert_notifications_in_extra(extra, expected) + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x4.PACKER_CLS, + unpacker_cls=AsyncBolt5x4.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x4( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + tag, fields = await sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x4.PACKER_CLS, + unpacker_cls=AsyncBolt5x4.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x4( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + tag, fields = await sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_async_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0") + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +async def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x4.PACKER_CLS, + unpacker_cls=AsyncBolt5x4.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x4(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + await connection.send_all() + tag, fields = await sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x4.PACKER_CLS, + unpacker_cls=AsyncBolt5x4.UNPACKER_CLS) + connection = AsyncBolt5x4(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index e65e456f4..bee79f798 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -41,6 +41,7 @@ TrustCustomCAs, TrustSystemCAs, ) +from neo4j._api import TelemetryAPI from neo4j._async.driver import _work from neo4j._async.io import ( AsyncBoltPool, @@ -644,9 +645,10 @@ async def test_execute_query_query( session_mock = session_cls_mock.return_value session_mock.__aenter__.assert_awaited_once() session_mock.__aexit__.assert_awaited_once() - session_executor_mock = session_mock.execute_write + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_awaited_once_with( - _work, query, mocker.ANY, mocker.ANY + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (query, mocker.ANY, mocker.ANY), {} ) assert res is session_executor_mock.return_value @@ -676,9 +678,10 @@ async def test_execute_query_parameters( session_mock = session_cls_mock.return_value session_mock.__aenter__.assert_awaited_once() session_mock.__aexit__.assert_awaited_once() - session_executor_mock = session_mock.execute_write + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_awaited_once_with( - _work, mocker.ANY, parameters or {}, mocker.ANY + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (mocker.ANY, parameters or {}, mocker.ANY), {} ) assert res is session_executor_mock.return_value @@ -702,9 +705,10 @@ async def test_execute_query_keyword_parameters( session_mock = session_cls_mock.return_value session_mock.__aenter__.assert_awaited_once() session_mock.__aexit__.assert_awaited_once() - session_executor_mock = session_mock.execute_write + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_awaited_once_with( - _work, mocker.ANY, parameters or {}, mocker.ANY + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (mocker.ANY, parameters or {}, mocker.ANY), {} ) assert res is session_executor_mock.return_value @@ -779,27 +783,28 @@ async def test_execute_query_parameter_precedence( session_mock = session_cls_mock.return_value session_mock.__aenter__.assert_awaited_once() session_mock.__aexit__.assert_awaited_once() - session_executor_mock = session_mock.execute_write + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_awaited_once_with( - _work, mocker.ANY, expected_params, mocker.ANY + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (mocker.ANY, expected_params, mocker.ANY), {} ) assert res is session_executor_mock.return_value @pytest.mark.parametrize( - ("routing_mode", "session_executor"), + ("routing_mode", "mode"), ( - (None, "execute_write"), - ("r", "execute_read"), - ("w", "execute_write"), - (neo4j.RoutingControl.READ, "execute_read"), - (neo4j.RoutingControl.WRITE, "execute_write"), + (None, WRITE_ACCESS), + ("r", READ_ACCESS), + ("w", WRITE_ACCESS), + (neo4j.RoutingControl.READ, READ_ACCESS), + (neo4j.RoutingControl.WRITE, WRITE_ACCESS), ) ) @pytest.mark.parametrize("positional", (True, False)) @mark_async_test async def test_execute_query_routing_control( - session_executor: str, positional: bool, + mode: str, positional: bool, routing_mode: t.Union[neo4j.RoutingControl, te.Literal["r", "w"], None], session_cls_mock, mocker ) -> None: @@ -817,9 +822,10 @@ async def test_execute_query_routing_control( session_mock = session_cls_mock.return_value session_mock.__aenter__.assert_awaited_once() session_mock.__aexit__.assert_awaited_once() - session_executor_mock = getattr(session_mock, session_executor) + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_awaited_once_with( - _work, mocker.ANY, mocker.ANY, mocker.ANY + mode, TelemetryAPI.DRIVER, _work, + (mocker.ANY, mocker.ANY, mocker.ANY), {} ) assert res is session_executor_mock.return_value @@ -935,9 +941,10 @@ async def test_execute_query_result_transformer( session_mock = session_cls_mock.return_value session_mock.__aenter__.assert_awaited_once() session_mock.__aexit__.assert_awaited_once() - session_executor_mock = session_mock.execute_write + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_awaited_once_with( - _work, mocker.ANY, mocker.ANY, result_transformer + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (mocker.ANY, mocker.ANY, result_transformer), {} ) assert res is session_executor_mock.return_value diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index fdb51e365..00fcd1109 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -27,12 +27,17 @@ Bookmarks, unit_of_work, ) +from neo4j._api import TelemetryAPI from neo4j._async.io import ( AsyncBoltPool, AsyncNeo4jPool, ) from neo4j._conf import SessionConfig -from neo4j.api import AsyncBookmarkManager +from neo4j.api import ( + AsyncBookmarkManager, + READ_ACCESS, + WRITE_ACCESS, +) from ...._async_compat import mark_async_test @@ -566,3 +571,60 @@ async def test_run_notification_disabled_categories(fake_pool, routing): connection_mock.run.assert_called_once() call_kwargs = connection_mock.run.call_args.kwargs assert call_kwargs["notifications_disabled_categories"] is dis_cats + + +@mark_async_test +async def test_session_run_api_telemetry(fake_pool): + async with AsyncSession(fake_pool, SessionConfig()) as session: + await session.run("RETURN 1") + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] + connection_mock.telemetry.assert_called_once() + call_args = connection_mock.telemetry.call_args.args + assert call_args[0] == TelemetryAPI.AUTO_COMMIT + + +@mark_async_test +async def test_session_unmanaged_transaction_api_telemetry(fake_pool): + async with AsyncSession(fake_pool, SessionConfig()) as session: + await session.begin_transaction() + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] + connection_mock.telemetry.assert_called_once() + call_args = connection_mock.telemetry.call_args.args + assert call_args[0] == TelemetryAPI.TX + + +@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction", + "execute_write", "execute_read",)) +@mark_async_test +async def test_session_managed_transaction_api_telemetry(fake_pool, tx_type): + async def work(_): + pass + + async with AsyncSession(fake_pool, SessionConfig()) as session: + with assert_warns_tx_func_deprecation(tx_type): + await getattr(session, tx_type)(work) + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] + connection_mock.telemetry.assert_called_once() + call_args = connection_mock.telemetry.call_args.args + assert call_args[0] == TelemetryAPI.TX_FUNC + + +@pytest.mark.parametrize("mode", (WRITE_ACCESS, READ_ACCESS)) +@mark_async_test +async def test_session_custom_api_telemetry(fake_pool, mode): + async def work(_): + pass + + async with AsyncSession(fake_pool, SessionConfig()) as session: + await session._run_transaction( + mode, TelemetryAPI.DRIVER, + work, (), {} + ) + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] + connection_mock.telemetry.assert_called_once() + call_args = connection_mock.telemetry.call_args.args + assert call_args[0] == TelemetryAPI.DRIVER diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index ce8bceb4c..52e1def86 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -58,6 +58,7 @@ "auth": None, "notifications_min_severity": None, "notifications_disabled_categories": None, + "telemetry_disabled": False, } test_session_config = { diff --git a/tests/unit/common/work/test_summary.py b/tests/unit/common/work/test_summary.py index 3d3049356..fbc198318 100644 --- a/tests/unit/common/work/test_summary.py +++ b/tests/unit/common/work/test_summary.py @@ -308,6 +308,7 @@ def test_summary_result_counters(summary_args_kwargs, counters_set) -> None: ((5, 1), "t_first"), ((5, 2), "t_first"), ((5, 3), "t_first"), + ((5, 4), "t_first"), )) def test_summary_result_available_after( summary_args_kwargs, exists, bolt_version, meta_name @@ -338,6 +339,7 @@ def test_summary_result_available_after( ((5, 1), "t_last"), ((5, 2), "t_last"), ((5, 3), "t_last"), + ((5, 4), "t_last"), )) def test_summary_result_consumed_after( summary_args_kwargs, exists, bolt_version, meta_name diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index 9c74816d3..657aebc2e 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -39,7 +39,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), } protocol_handlers = Bolt.protocol_handlers() @@ -65,7 +65,8 @@ def test_class_method_protocol_handlers(): ((5, 1), 1), ((5, 2), 1), ((5, 3), 1), - ((5, 4), 0), + ((5, 4), 1), + ((5, 5), 0), ((6, 0), 0), ] ) @@ -85,7 +86,7 @@ def test_class_method_protocol_handlers_with_invalid_protocol_version(): # [bolt-version-bump] search tag when changing bolt version support def test_class_method_get_handshake(): handshake = Bolt.get_handshake() - assert (b"\x00\x03\x03\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + assert (b"\x00\x04\x04\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" == handshake) @@ -132,6 +133,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ((5, 1), "neo4j._sync.io._bolt5.Bolt5x1"), ((5, 2), "neo4j._sync.io._bolt5.Bolt5x2"), ((5, 3), "neo4j._sync.io._bolt5.Bolt5x3"), + ((5, 4), "neo4j._sync.io._bolt5.Bolt5x4"), ), ) @mark_sync_test @@ -164,13 +166,15 @@ def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 4), + (5, 5), (6, 0), )) @mark_sync_test def test_failing_version_negotiation(mocker, bolt_version, none_auth): - supported_protocols = \ - "('3.0', '4.1', '4.2', '4.3', '4.4', '5.0', '5.1', '5.2', '5.3')" + supported_protocols = ( + "('3.0', '4.1', '4.2', '4.3', '4.4', " + "'5.0', '5.1', '5.2', '5.3', '5.4')" + ) address = ("localhost", 7687) socket_mock = mocker.MagicMock(spec=BoltSocket) @@ -188,7 +192,6 @@ def test_failing_version_negotiation(mocker, bolt_version, none_auth): assert exc.match(supported_protocols) - @TestDecorators.mark_async_only_test def test_cancel_manager_in_open(mocker): address = ("localhost", 7687) diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index 0ebdfb144..f9426d5f1 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -14,6 +14,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + import contextlib import itertools import logging @@ -21,6 +23,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j._sync.io._bolt3 import Bolt3 @@ -97,6 +100,28 @@ def test_simple_pull(fake_socket): assert len(fields) == 0 +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt3.UNPACKER_CLS) + connection = Bolt3( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + with pytest.raises(OSError): + socket.pop_message() + + @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_sync_test def test_hint_recv_timeout_seconds_gets_ignored( diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index 497723cd6..9f97eb365 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -16,13 +16,13 @@ # limitations under the License. -import contextlib import itertools import logging import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j._sync.io._bolt4 import Bolt4x0 @@ -195,6 +195,28 @@ def test_n_and_qid_extras_in_pull(fake_socket): assert fields[0] == {"n": 666, "qid": 777} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) + connection = Bolt4x0( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + with pytest.raises(OSError): + socket.pop_message() + + @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_sync_test def test_hint_recv_timeout_seconds_gets_ignored( diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index a6f4adfb7..64bfb6629 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j._sync.io._bolt4 import Bolt4x1 @@ -212,6 +213,28 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) + connection = Bolt4x1( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + with pytest.raises(OSError): + socket.pop_message() + + @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_sync_test def test_hint_recv_timeout_seconds_gets_ignored( diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index 820fb7333..98a64a440 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j._sync.io._bolt4 import Bolt4x2 @@ -212,6 +213,28 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) + connection = Bolt4x2( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + with pytest.raises(OSError): + socket.pop_message() + + @pytest.mark.parametrize("recv_timeout", (1, -1)) @mark_sync_test def test_hint_recv_timeout_seconds_gets_ignored( diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 02cf91393..957d9e0c3 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j._sync.io._bolt4 import Bolt4x3 @@ -212,6 +213,28 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) + connection = Bolt4x3( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + with pytest.raises(OSError): + socket.pop_message() + + @pytest.mark.parametrize(("hints", "valid"), ( ({"connection.recv_timeout_seconds": 1}, True), ({"connection.recv_timeout_seconds": 42}, True), diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index b417a4c19..0d01782da 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j._sync.io._bolt4 import Bolt4x4 @@ -226,6 +227,28 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) + connection = Bolt4x4( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + with pytest.raises(OSError): + socket.pop_message() + + @pytest.mark.parametrize(("hints", "valid"), ( ({"connection.recv_timeout_seconds": 1}, True), ({"connection.recv_timeout_seconds": 42}, True), diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index 3beff0350..d261d9aaf 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j._sync.io._bolt5 import Bolt5x0 @@ -226,6 +227,28 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) + connection = Bolt5x0( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + with pytest.raises(OSError): + socket.pop_message() + + @pytest.mark.parametrize(("hints", "valid"), ( ({"connection.recv_timeout_seconds": 1}, True), ({"connection.recv_timeout_seconds": 42}, True), diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index 02178cfea..f5b71bc84 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -23,6 +23,7 @@ import neo4j import neo4j.exceptions +from neo4j._api import TelemetryAPI from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j._sync.io._bolt5 import Bolt5x1 @@ -239,6 +240,28 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + with pytest.raises(OSError): + socket.pop_message() + + def _assert_logon_message(sockets, auth): tag, fields = sockets.server.pop_message() assert tag == b"\x6A" # LOGON diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py index 99f8daee6..047c32e6a 100644 --- a/tests/unit/sync/io/test_class_bolt5x2.py +++ b/tests/unit/sync/io/test_class_bolt5x2.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._conf import PoolConfig from neo4j._meta import USER_AGENT from neo4j._sync.io._bolt5 import Bolt5x2 @@ -226,6 +227,28 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x2.UNPACKER_CLS) + connection = Bolt5x2( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + with pytest.raises(OSError): + socket.pop_message() + + def _assert_logon_message(sockets, auth): tag, fields = sockets.server.pop_message() assert tag == b"\x6A" # LOGON diff --git a/tests/unit/sync/io/test_class_bolt5x3.py b/tests/unit/sync/io/test_class_bolt5x3.py index 7a3dd884c..d573ffb8e 100644 --- a/tests/unit/sync/io/test_class_bolt5x3.py +++ b/tests/unit/sync/io/test_class_bolt5x3.py @@ -22,6 +22,7 @@ import pytest import neo4j +from neo4j._api import TelemetryAPI from neo4j._conf import PoolConfig from neo4j._meta import ( BOLT_AGENT_DICT, @@ -229,6 +230,28 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x3.UNPACKER_CLS) + connection = Bolt5x3( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + with pytest.raises(OSError): + socket.pop_message() + + @pytest.mark.parametrize(("hints", "valid"), ( ({"connection.recv_timeout_seconds": 1}, True), ({"connection.recv_timeout_seconds": 42}, True), diff --git a/tests/unit/sync/io/test_class_bolt5x4.py b/tests/unit/sync/io/test_class_bolt5x4.py new file mode 100644 index 000000000..5b5570d71 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x4.py @@ -0,0 +1,589 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._conf import PoolConfig +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j._sync.io._bolt5 import Bolt5x4 + +from ...._async_compat import mark_sync_test + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt5x4(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt5x4(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt5x4(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x4.PACKER_CLS, + unpacker_cls=Bolt5x4.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x4( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + socket.pop_message() + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x4.PACKER_CLS, + unpacker_cls=Bolt5x4.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x4( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x4.PACKER_CLS, + unpacker_cls=Bolt5x4.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x4( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + + +@pytest.mark.parametrize(("method", "args", "extra_idx"), ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), +)) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2) +) +@pytest.mark.parametrize( + ("cls_dis_cats", "method_dis_cats"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2) +) +@mark_sync_test +def test_supports_notification_filters( + fake_socket, method, args, extra_idx, cls_min_sev, method_min_sev, + cls_dis_cats, method_dis_cats +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4( + address, socket, PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_categories=cls_dis_cats + ) + method = getattr(connection, method) + + method(*args, notifications_min_severity=method_min_sev, + notifications_disabled_categories=method_dis_cats) + connection.send_all() + + _, fields = socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_cats is not None: + expected["notifications_disabled_categories"] = method_dis_cats + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize("dis_cats", + (None, [], ["HINT"], ["HINT", "DEPRECATION"])) +@mark_sync_test +def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_cats +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x4.PACKER_CLS, + unpacker_cls=Bolt5x4.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x4( + address, sockets.client, PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_categories=dis_cats + ) + + connection.hello() + + tag, fields = sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_cats is not None: + expected["notifications_disabled_categories"] = dis_cats + _assert_notifications_in_extra(extra, expected) + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x4.PACKER_CLS, + unpacker_cls=Bolt5x4.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x4( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + tag, fields = sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x4.PACKER_CLS, + unpacker_cls=Bolt5x4.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x4( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + tag, fields = sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_sync_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0") + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x4.PACKER_CLS, + unpacker_cls=Bolt5x4.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x4(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x4.PACKER_CLS, + unpacker_cls=Bolt5x4.UNPACKER_CLS) + connection = Bolt5x4(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index 18e1afa69..74faa8b64 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -41,6 +41,7 @@ TrustCustomCAs, TrustSystemCAs, ) +from neo4j._api import TelemetryAPI from neo4j._conf import ( PoolConfig, SessionConfig, @@ -643,9 +644,10 @@ def test_execute_query_query( session_mock = session_cls_mock.return_value session_mock.__enter__.assert_called_once() session_mock.__exit__.assert_called_once() - session_executor_mock = session_mock.execute_write + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_called_once_with( - _work, query, mocker.ANY, mocker.ANY + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (query, mocker.ANY, mocker.ANY), {} ) assert res is session_executor_mock.return_value @@ -675,9 +677,10 @@ def test_execute_query_parameters( session_mock = session_cls_mock.return_value session_mock.__enter__.assert_called_once() session_mock.__exit__.assert_called_once() - session_executor_mock = session_mock.execute_write + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_called_once_with( - _work, mocker.ANY, parameters or {}, mocker.ANY + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (mocker.ANY, parameters or {}, mocker.ANY), {} ) assert res is session_executor_mock.return_value @@ -701,9 +704,10 @@ def test_execute_query_keyword_parameters( session_mock = session_cls_mock.return_value session_mock.__enter__.assert_called_once() session_mock.__exit__.assert_called_once() - session_executor_mock = session_mock.execute_write + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_called_once_with( - _work, mocker.ANY, parameters or {}, mocker.ANY + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (mocker.ANY, parameters or {}, mocker.ANY), {} ) assert res is session_executor_mock.return_value @@ -778,27 +782,28 @@ def test_execute_query_parameter_precedence( session_mock = session_cls_mock.return_value session_mock.__enter__.assert_called_once() session_mock.__exit__.assert_called_once() - session_executor_mock = session_mock.execute_write + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_called_once_with( - _work, mocker.ANY, expected_params, mocker.ANY + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (mocker.ANY, expected_params, mocker.ANY), {} ) assert res is session_executor_mock.return_value @pytest.mark.parametrize( - ("routing_mode", "session_executor"), + ("routing_mode", "mode"), ( - (None, "execute_write"), - ("r", "execute_read"), - ("w", "execute_write"), - (neo4j.RoutingControl.READ, "execute_read"), - (neo4j.RoutingControl.WRITE, "execute_write"), + (None, WRITE_ACCESS), + ("r", READ_ACCESS), + ("w", WRITE_ACCESS), + (neo4j.RoutingControl.READ, READ_ACCESS), + (neo4j.RoutingControl.WRITE, WRITE_ACCESS), ) ) @pytest.mark.parametrize("positional", (True, False)) @mark_sync_test def test_execute_query_routing_control( - session_executor: str, positional: bool, + mode: str, positional: bool, routing_mode: t.Union[neo4j.RoutingControl, te.Literal["r", "w"], None], session_cls_mock, mocker ) -> None: @@ -816,9 +821,10 @@ def test_execute_query_routing_control( session_mock = session_cls_mock.return_value session_mock.__enter__.assert_called_once() session_mock.__exit__.assert_called_once() - session_executor_mock = getattr(session_mock, session_executor) + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_called_once_with( - _work, mocker.ANY, mocker.ANY, mocker.ANY + mode, TelemetryAPI.DRIVER, _work, + (mocker.ANY, mocker.ANY, mocker.ANY), {} ) assert res is session_executor_mock.return_value @@ -934,9 +940,10 @@ def test_execute_query_result_transformer( session_mock = session_cls_mock.return_value session_mock.__enter__.assert_called_once() session_mock.__exit__.assert_called_once() - session_executor_mock = session_mock.execute_write + session_executor_mock = session_mock._run_transaction session_executor_mock.assert_called_once_with( - _work, mocker.ANY, mocker.ANY, result_transformer + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (mocker.ANY, mocker.ANY, result_transformer), {} ) assert res is session_executor_mock.return_value diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index d54581bdc..27aae6ab7 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -27,12 +27,17 @@ Transaction, unit_of_work, ) +from neo4j._api import TelemetryAPI from neo4j._conf import SessionConfig from neo4j._sync.io import ( BoltPool, Neo4jPool, ) -from neo4j.api import BookmarkManager +from neo4j.api import ( + BookmarkManager, + READ_ACCESS, + WRITE_ACCESS, +) from ...._async_compat import mark_sync_test @@ -566,3 +571,60 @@ def test_run_notification_disabled_categories(fake_pool, routing): connection_mock.run.assert_called_once() call_kwargs = connection_mock.run.call_args.kwargs assert call_kwargs["notifications_disabled_categories"] is dis_cats + + +@mark_sync_test +def test_session_run_api_telemetry(fake_pool): + with Session(fake_pool, SessionConfig()) as session: + session.run("RETURN 1") + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] + connection_mock.telemetry.assert_called_once() + call_args = connection_mock.telemetry.call_args.args + assert call_args[0] == TelemetryAPI.AUTO_COMMIT + + +@mark_sync_test +def test_session_unmanaged_transaction_api_telemetry(fake_pool): + with Session(fake_pool, SessionConfig()) as session: + session.begin_transaction() + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] + connection_mock.telemetry.assert_called_once() + call_args = connection_mock.telemetry.call_args.args + assert call_args[0] == TelemetryAPI.TX + + +@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction", + "execute_write", "execute_read",)) +@mark_sync_test +def test_session_managed_transaction_api_telemetry(fake_pool, tx_type): + def work(_): + pass + + with Session(fake_pool, SessionConfig()) as session: + with assert_warns_tx_func_deprecation(tx_type): + getattr(session, tx_type)(work) + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] + connection_mock.telemetry.assert_called_once() + call_args = connection_mock.telemetry.call_args.args + assert call_args[0] == TelemetryAPI.TX_FUNC + + +@pytest.mark.parametrize("mode", (WRITE_ACCESS, READ_ACCESS)) +@mark_sync_test +def test_session_custom_api_telemetry(fake_pool, mode): + def work(_): + pass + + with Session(fake_pool, SessionConfig()) as session: + session._run_transaction( + mode, TelemetryAPI.DRIVER, + work, (), {} + ) + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] + connection_mock.telemetry.assert_called_once() + call_args = connection_mock.telemetry.call_args.args + assert call_args[0] == TelemetryAPI.DRIVER