diff --git a/bin/make-unasync b/bin/make-unasync index f1b99465f..675a6bc4c 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -154,12 +154,12 @@ class CustomRule(unasync.Rule): start += 1 end += 1 else: - out += self._unasync_prefix(name[start:(end - 1)]) + out += self._unasync_name(name[start:(end - 1)]) start = end - 1 sub_name = name[start:] if sub_name.isidentifier(): - out += self._unasync_prefix(name[start:]) + out += self._unasync_name(name[start:]) else: out += sub_name @@ -221,6 +221,7 @@ def apply_unasync(files): "mark_async_test": "mark_sync_test", "assert_awaited_once": "assert_called_once", "assert_awaited_once_with": "assert_called_once_with", + "await_count": "call_count", } additional_testkit_backend_replacements = {} rules = [ diff --git a/docs/source/api.rst b/docs/source/api.rst index f99b2909f..fc19dbdd8 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -14,40 +14,41 @@ Driver Construction The :class:`neo4j.Driver` construction is done via a ``classmethod`` on the :class:`neo4j.GraphDatabase` class. .. autoclass:: neo4j.GraphDatabase - :members: driver + :members: bookmark_manager + .. method:: driver -Driver creation example: + Driver creation example: -.. code-block:: python + .. code-block:: python - from neo4j import GraphDatabase + from neo4j import GraphDatabase - uri = "neo4j://example.com:7687" - driver = GraphDatabase.driver(uri, auth=("neo4j", "password")) + uri = "neo4j://example.com:7687" + driver = GraphDatabase.driver(uri, auth=("neo4j", "password")) - driver.close() # close the driver object + driver.close() # close the driver object -For basic authentication, ``auth`` can be a simple tuple, for example: + For basic authentication, ``auth`` can be a simple tuple, for example: -.. code-block:: python + .. code-block:: python - auth = ("neo4j", "password") + auth = ("neo4j", "password") -This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``. -Other authentication methods are described under :ref:`auth-ref`. + This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``. + Other authentication methods are described under :ref:`auth-ref`. -``with`` block context example: + ``with`` block context example: -.. code-block:: python + .. code-block:: python - from neo4j import GraphDatabase + from neo4j import GraphDatabase - uri = "neo4j://example.com:7687" - with GraphDatabase.driver(uri, auth=("neo4j", "password")) as driver: - # use the driver + uri = "neo4j://example.com:7687" + with GraphDatabase.driver(uri, auth=("neo4j", "password")) as driver: + # use the driver @@ -138,7 +139,7 @@ Alternatively, one of the auth token helper functions can be used. Driver ****** -Every Neo4j-backed application will require a :class:`neo4j.Driver` object. +Every Neo4j-backed application will require a driver object. This object holds the details required to establish connections with a Neo4j database, including server URIs, credentials and other configuration. :class:`neo4j.Driver` objects hold a connection pool from which :class:`neo4j.Session` objects can borrow connections. @@ -274,6 +275,7 @@ Specify whether TCP keep-alive should be enabled. :Default: ``True`` **This is experimental.** (See :ref:`filter-warnings-ref`) +It might be changed or removed any time even without prior notice. .. _max-connection-lifetime-ref: @@ -576,6 +578,7 @@ To construct a :class:`neo4j.Session` use the :meth:`neo4j.Driver.session` metho + :ref:`database-ref` + :ref:`default-access-mode-ref` + :ref:`fetch-size-ref` ++ :ref:`bookmark-manager-ref` .. _bookmarks-ref: @@ -704,6 +707,33 @@ The fetch size used for requesting messages from Neo4j. :Default: ``1000`` +.. _bookmark-manager-ref: + +``bookmark_manager`` +-------------------- +Specify a bookmark manager for the session to use. If present, the bookmark +manager is used to keep all work within the session causally consistent with +all work in other sessions using the same bookmark manager. + +See :class:`.BookmarkManager` for more information. + +.. warning:: + Enabling the BookmarkManager can have a negative impact on performance since + all queries will wait for the latest changes to be propagated across the + cluster. + + For simple use-cases, it often suffices that work within a single session + is automatically causally consistent. + +:Type: :const:`None` or :class:`.BookmarkManager` +:Default: :const:`None` + +.. versionadded:: 5.0 + +**This is experimental.** (See :ref:`filter-warnings-ref`) +It might be changed or removed any time even without prior notice. + + *********** @@ -829,7 +859,7 @@ Returning a live result object would prevent the driver from correctly managing This function will receive a :class:`neo4j.ManagedTransaction` object as its first parameter. -.. autoclass:: neo4j.ManagedTransaction +.. autoclass:: neo4j.ManagedTransaction() .. automethod:: run @@ -911,6 +941,7 @@ Graph .. automethod:: relationship_type **This is experimental.** (See :ref:`filter-warnings-ref`) +It might be changed or removed any time even without prior notice. ****** @@ -1231,6 +1262,14 @@ Temporal Data Types See topic :ref:`temporal-data-types` for more details. +*************** +BookmarkManager +*************** + +.. autoclass:: neo4j.api.BookmarkManager + :members: + + .. _errors-ref: ****** @@ -1526,6 +1565,7 @@ Bookmarks .. autoclass:: neo4j.Bookmarks :members: + :special-members: __bool__, __add__, __iter__ .. autoclass:: neo4j.Bookmark :members: diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index c9bdd4985..3bf4c34ac 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -21,51 +21,53 @@ Async Driver Construction The :class:`neo4j.AsyncDriver` construction is done via a ``classmethod`` on the :class:`neo4j.AsyncGraphDatabase` class. .. autoclass:: neo4j.AsyncGraphDatabase - :members: driver + :members: bookmark_manager + .. automethod:: driver -Driver creation example: + Driver creation example: -.. code-block:: python + .. code-block:: python - import asyncio + import asyncio - from neo4j import AsyncGraphDatabase + from neo4j import AsyncGraphDatabase - async def main(): - uri = "neo4j://example.com:7687" - driver = AsyncGraphDatabase.driver(uri, auth=("neo4j", "password")) + async def main(): + uri = "neo4j://example.com:7687" + driver = AsyncGraphDatabase.driver(uri, auth=("neo4j", "password")) - await driver.close() # close the driver object + await driver.close() # close the driver object - asyncio.run(main()) + asyncio.run(main()) -For basic authentication, ``auth`` can be a simple tuple, for example: + For basic authentication, ``auth`` can be a simple tuple, for example: -.. code-block:: python + .. code-block:: python - auth = ("neo4j", "password") + auth = ("neo4j", "password") -This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``. -Other authentication methods are described under :ref:`auth-ref`. + This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``. + Other authentication methods are described under :ref:`auth-ref`. -``with`` block context example: + ``with`` block context example: -.. code-block:: python + .. code-block:: python - import asyncio + import asyncio - from neo4j import AsyncGraphDatabase + from neo4j import AsyncGraphDatabase - async def main(): - uri = "neo4j://example.com:7687" - auth = ("neo4j", "password") - async with AsyncGraphDatabase.driver(uri, auth=auth) as driver: - # use the driver - ... + async def main(): + uri = "neo4j://example.com:7687" + auth = ("neo4j", "password") + async with AsyncGraphDatabase.driver(uri, auth=auth) as driver: + # use the driver + ... + + asyncio.run(main()) - asyncio.run(main()) .. _async-uri-ref: @@ -120,7 +122,7 @@ Each supported scheme maps to a particular :class:`neo4j.AsyncDriver` subclass t AsyncDriver *********** -Every Neo4j-backed application will require a :class:`neo4j.AsyncDriver` object. +Every Neo4j-backed application will require a driver object. This object holds the details required to establish connections with a Neo4j database, including server URIs, credentials and other configuration. :class:`neo4j.AsyncDriver` objects hold a connection pool from which :class:`neo4j.AsyncSession` objects can borrow connections. @@ -144,6 +146,7 @@ Async Driver Configuration (see :ref:`driver-configuration-ref`). The only difference is that the async driver accepts an async custom resolver function: + .. _async-resolver-ref: ``resolver`` @@ -366,7 +369,35 @@ Session Configuration ===================== :class:`neo4j.AsyncSession` is configured exactly like :class:`neo4j.Session` -(see :ref:`session-configuration-ref`). +(see :ref:`session-configuration-ref`). The only difference is the async session +accepts either a :class:`neo4j.api.BookmarkManager` object or a +:class:`neo4j.api.AsyncBookmarkManager` as bookmark manager: + + +.. _async-bookmark-manager-ref: + +``bookmark_manager`` +-------------------- +Specify a bookmark manager for the driver to use. If present, the bookmark +manager is used to keep all work on the driver causally consistent. + +See :class:`BookmarkManager` for more information. + +.. warning:: + Enabling the BookmarkManager can have a negative impact on performance since + all queries will wait for the latest changes to be propagated across the + cluster. + + For simpler use-cases, sessions (:class:`.AsyncSession`) can be used to + group a series of queries together that will be causally chained + automatically. + +:Type: :const:`None`, :class:`BookmarkManager`, or :class:`AsyncBookmarkManager` +:Default: :const:`None` + +**This is experimental.** (See :ref:`filter-warnings-ref`) +It might be changed or removed any time even without prior notice. + **************** @@ -501,7 +532,7 @@ Returning a live result object would prevent the driver from correctly managing This function will receive a :class:`neo4j.AsyncManagedTransaction` object as its first parameter. -.. autoclass:: neo4j.AsyncManagedTransaction +.. autoclass:: neo4j.AsyncManagedTransaction() .. automethod:: run @@ -522,7 +553,6 @@ Example: To exert more control over how a transaction function is carried out, the :func:`neo4j.unit_of_work` decorator can be used. - *********** AsyncResult *********** @@ -568,6 +598,13 @@ A :class:`neo4j.AsyncResult` is attached to an active connection, through a :cla See https://neo4j.com/docs/python-manual/current/cypher-workflow/#python-driver-type-mapping for more about type mapping. +******************** +AsyncBookmarkManager +******************** + +.. autoclass:: neo4j.api.AsyncBookmarkManager + :members: + ****************** Async Cancellation diff --git a/neo4j/_async/bookmark_manager.py b/neo4j/_async/bookmark_manager.py new file mode 100644 index 000000000..41c098cc7 --- /dev/null +++ b/neo4j/_async/bookmark_manager.py @@ -0,0 +1,106 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import typing as t +from collections import defaultdict + +from .._async_compat.concurrency import AsyncCooperativeLock +from .._async_compat.util import AsyncUtil +from ..api import ( + AsyncBookmarkManager, + Bookmarks, +) + + +T_BmSupplier = t.Callable[[t.Optional[str]], + t.Union[Bookmarks, t.Awaitable[Bookmarks]]] +T_BmConsumer = t.Callable[[str, Bookmarks], t.Union[None, t.Awaitable[None]]] + + +def _bookmarks_to_set( + bookmarks: t.Union[Bookmarks, t.Iterable[str]] +) -> t.Set[str]: + if isinstance(bookmarks, Bookmarks): + return set(bookmarks.raw_values) + return set(map(str, bookmarks)) + + +class AsyncNeo4jBookmarkManager(AsyncBookmarkManager): + def __init__( + self, + initial_bookmarks: t.Mapping[str, t.Union[Bookmarks, + t.Iterable[str]]] = None, + bookmarks_supplier: T_BmSupplier = None, + bookmarks_consumer: T_BmConsumer = None + ) -> None: + super().__init__() + self._bookmarks_supplier = bookmarks_supplier + self._bookmarks_consumer = bookmarks_consumer + if initial_bookmarks is None: + initial_bookmarks = {} + self._bookmarks = defaultdict( + set, ((k, _bookmarks_to_set(v)) + for k, v in initial_bookmarks.items()) + ) + self._lock = AsyncCooperativeLock() + + async def update_bookmarks( + self, database: str, previous_bookmarks: t.Collection[str], + new_bookmarks: t.Collection[str] + ) -> None: + if not new_bookmarks: + return + with self._lock: + curr_bms = self._bookmarks[database] + curr_bms.difference_update(previous_bookmarks) + curr_bms.update(new_bookmarks) + if self._bookmarks_consumer: + curr_bms_snapshot = Bookmarks.from_raw_values(curr_bms) + if self._bookmarks_consumer: + await AsyncUtil.callback( + self._bookmarks_consumer, database, curr_bms_snapshot + ) + + async def get_bookmarks(self, database: str) -> t.Set[str]: + with self._lock: + bms = set(self._bookmarks[database]) + if self._bookmarks_supplier: + extra_bms = await AsyncUtil.callback( + self._bookmarks_supplier, database + ) + bms.update(extra_bms.raw_values) + return bms + + async def get_all_bookmarks(self) -> t.Set[str]: + bms: t.Set[str] = set() + with self._lock: + for database in self._bookmarks.keys(): + bms.update(self._bookmarks[database]) + if self._bookmarks_supplier: + extra_bms = await AsyncUtil.callback( + self._bookmarks_supplier, None + ) + bms.update(extra_bms.raw_values) + return bms + + async def forget(self, databases: t.Iterable[str]) -> None: + with self._lock: + for database in databases: + self._bookmarks.pop(database, None) diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index 3efe3920f..36d56732d 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -43,7 +43,9 @@ ) from ..addressing import Address from ..api import ( + AsyncBookmarkManager, Auth, + BookmarkManager, Bookmarks, DRIVER_BOLT, DRIVER_NEO4J, @@ -62,11 +64,16 @@ URI_SCHEME_NEO4J_SECURE, URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, ) +from .bookmark_manager import ( + AsyncNeo4jBookmarkManager, + T_BmConsumer as _T_BmConsumer, + T_BmSupplier as _T_BmSupplier, +) from .work import AsyncSession class AsyncGraphDatabase: - """Accessor for :class:`neo4j.Driver` construction. + """Accessor for :class:`neo4j.AsyncDriver` construction. """ if t.TYPE_CHECKING: @@ -103,7 +110,9 @@ def driver( retry_delay_jitter_factor: float = ..., database: t.Optional[str] = ..., fetch_size: int = ..., - impersonated_user: t.Optional[str] = ... + impersonated_user: t.Optional[str] = ..., + bookmark_manager: t.Union[AsyncBookmarkManager, + BookmarkManager, None] = ... ) -> AsyncDriver: ... @@ -202,6 +211,81 @@ def driver(cls, uri, *, auth=None, **config) -> AsyncDriver: return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config) + @classmethod + @experimental( + "The bookmark manager feature is experimental. " + "It might be changed or removed any time even without prior notice." + ) + def bookmark_manager( + cls, + initial_bookmarks: t.Mapping[str, t.Union[Bookmarks, + t.Iterable[str]]] = None, + bookmarks_supplier: _T_BmSupplier = None, + bookmarks_consumer: _T_BmConsumer = None + ) -> AsyncBookmarkManager: + """Create a :class:`.AsyncBookmarkManager` with default implementation. + + Basic usage example to configure sessions with the built-in bookmark + manager implementation so that all work is automatically causally + chained (i.e., all reads can observe all previous writes even in a + clustered setup):: + + import neo4j + + driver = neo4j.AsyncGraphDatabase.driver(...) + bookmark_manager = neo4j.AsyncBookmarkManager(...) + + async with driver.session( + bookmark_manager=bookmark_manager + ) as session1: + async with driver.session( + bookmark_manager=bookmark_manager + ) as session2: + session1.run("") + # READ_QUERY is guaranteed to see what WRITE_QUERY wrote. + session2.run("") + + This is a very contrived example, and in this particular case, having + both queries in the same session has the exact same effect and might + even be more performant. However, when dealing with sessions spanning + multiple threads, async Tasks, processes, or even hosts, the bookmark + manager can come in handy as sessions are not safe to be used + concurrently. + + :param initial_bookmarks: + The initial set of bookmarks. The returned bookmark manager will + use this to initialize its internal bookmarks per database. + If present, this parameter must be a mapping of database names + to :class:`.Bookmarks` or an iterable of raw bookmark values (str). + :param bookmarks_supplier: + Function which will be called every time the default bookmark + manager's method :meth:`.AsyncBookmarkManager.get_bookmarks` + or :meth:`.AsyncBookmarkManager.get_all_bookmarks` gets called. + The function will be passed the name of the database (``str``) if + ``.get_bookmarks`` is called or ``None`` if ``.get_all_bookmarks`` + is called. The function must return a :class:`.Bookmarks` object. + The result of ``bookmarks_supplier`` will then be concatenated with + the internal set of bookmarks and used to configure the session in + creation. + :param bookmarks_consumer: + Function which will be called whenever the set of bookmarks + handled by the bookmark manager gets updated with the new + internal bookmark set. It will receive the name of the database + and the new set of bookmarks. + + :returns: A default implementation of :class:`AsyncBookmarkManager`. + + **This is experimental.** (See :ref:`filter-warnings-ref`) + It might be changed or removed any time even without prior notice. + + .. versionadded:: 5.0 + """ + return AsyncNeo4jBookmarkManager( + initial_bookmarks=initial_bookmarks, + bookmarks_supplier=bookmarks_supplier, + bookmarks_consumer=bookmarks_consumer + ) + @classmethod def bolt_driver(cls, target, *, auth=None, **config): """ Create a driver for direct Bolt server access that uses @@ -339,13 +423,16 @@ def session( fetch_size: int = ..., impersonated_user: t.Optional[str] = ..., bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + ignore_bookmark_manager: bool = ..., default_access_mode: str = ..., + bookmark_manager: t.Union[AsyncBookmarkManager, + BookmarkManager, None] = ..., # undocumented/unsupported options # they may be change or removed any time without prior notice initial_retry_delay: float = ..., retry_delay_multiplier: float = ..., - retry_delay_jitter_factor: float = ..., + retry_delay_jitter_factor: float = ... ) -> AsyncSession: ... @@ -382,11 +469,13 @@ async def verify_connectivity( impersonated_user: t.Optional[str] = ..., bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., default_access_mode: str = ..., + bookmark_manager: t.Union[AsyncBookmarkManager, + BookmarkManager, None] = ..., # undocumented/unsupported options initial_retry_delay: float = ..., retry_delay_multiplier: float = ..., - retry_delay_jitter_factor: float = ..., + retry_delay_jitter_factor: float = ... ) -> None: ... @@ -444,11 +533,13 @@ async def get_server_info( impersonated_user: t.Optional[str] = ..., bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., default_access_mode: str = ..., + bookmark_manager: t.Union[AsyncBookmarkManager, + BookmarkManager, None] = ..., # undocumented/unsupported options initial_retry_delay: float = ..., retry_delay_multiplier: float = ..., - retry_delay_jitter_factor: float = ..., + retry_delay_jitter_factor: float = ... ) -> ServerInfo: ... diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py index 82ff5ed60..b530b3a2a 100644 --- a/neo4j/_async/io/_pool.py +++ b/neo4j/_async/io/_pool.py @@ -515,9 +515,9 @@ async def fetch_routing_info( cx = await self._acquire(address, deadline, None) try: routing_table = await cx.route( - database or self.workspace_config.database, - imp_user or self.workspace_config.impersonated_user, - bookmarks + database=database or self.workspace_config.database, + imp_user=imp_user or self.workspace_config.impersonated_user, + bookmarks=bookmarks ) finally: await self.release(cx) diff --git a/neo4j/_async/work/result.py b/neo4j/_async/work/result.py index 066687073..b469bd25d 100644 --- a/neo4j/_async/work/result.py +++ b/neo4j/_async/work/result.py @@ -80,6 +80,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error): self._keys = None self._record_buffer = deque() self._summary = None + self._database = None self._bookmark = None self._raw_qid = -1 self._fetch_size = fetch_size @@ -127,7 +128,9 @@ async def _run( "query": query_text, "parameters": parameters, "server": self._connection.server_info, + "database": db, } + self._database = db def on_attached(metadata): self._metadata.update(metadata) @@ -189,6 +192,7 @@ def on_success(summary_metadata): return self._metadata.update(summary_metadata) self._bookmark = summary_metadata.get("bookmark") + self._database = summary_metadata.get("db", self._database) self._connection.pull( n=self._fetch_size, @@ -220,6 +224,7 @@ def on_success(summary_metadata): self._discarding = False self._metadata.update(summary_metadata) self._bookmark = summary_metadata.get("bookmark") + self._database = summary_metadata.get("db", self._database) # This was the last page received, discard the rest self._connection.discard( diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py index 4b6dce84d..35e9236d5 100644 --- a/neo4j/_async/work/session.py +++ b/neo4j/_async/work/session.py @@ -36,10 +36,7 @@ from ..._async_compat import async_sleep from ..._async_compat.util import AsyncUtil from ..._conf import SessionConfig -from ..._meta import ( - deprecated, - deprecation_warn, -) +from ..._meta import deprecated from ...api import ( Bookmarks, READ_ACCESS, @@ -100,9 +97,11 @@ class AsyncSession(AsyncWorkspace): _state_failed = False def __init__(self, pool, session_config): - super().__init__(pool, session_config) assert isinstance(session_config, SessionConfig) - self._bookmarks = self._prepare_bookmarks(session_config.bookmarks) + super().__init__(pool, session_config) + self._initialize_bookmarks(session_config.bookmarks) + if not session_config.ignore_bookmark_manager: + self._bookmark_manager = session_config.bookmark_manager async def __aenter__(self) -> AsyncSession: return self @@ -116,21 +115,6 @@ async def __aexit__(self, exception_type, exception_value, traceback): self._state_failed = True await self.close() - def _prepare_bookmarks(self, bookmarks): - if isinstance(bookmarks, Bookmarks): - return tuple(bookmarks.raw_values) - if hasattr(bookmarks, "__iter__"): - deprecation_warn( - "Passing an iterable as `bookmarks` to `Session` is " - "deprecated. Please use a `Bookmarks` instance.", - stack_level=5 - ) - return tuple(bookmarks) - if not bookmarks: - return () - raise TypeError("Bookmarks must be an instance of Bookmarks or an " - "iterable of raw bookmarks (deprecated).") - async def _connect(self, access_mode, **access_kwargs): if access_mode is None: access_mode = self._config.default_access_mode @@ -147,10 +131,6 @@ async def _disconnect(self, sync=False): self._handle_cancellation(message="_disconnect") raise - def _collect_bookmark(self, bookmark): - if bookmark: - self._bookmarks = bookmark, - def _handle_cancellation(self, message="General"): self._transaction = None self._auto_result = None @@ -165,7 +145,8 @@ def _handle_cancellation(self, message="General"): async def _result_closed(self): if self._auto_result: - self._collect_bookmark(self._auto_result._bookmark) + await self._update_bookmark(self._auto_result._database, + self._auto_result._bookmark) self._auto_result = None await self._disconnect() @@ -196,7 +177,10 @@ async def close(self) -> None: if self._state_failed is False: try: await self._auto_result.consume() - self._collect_bookmark(self._auto_result._bookmark) + await self._update_bookmark( + self._auto_result._database, + self._auto_result._bookmark + ) except Exception as error: # TODO: Investigate potential non graceful close states self._auto_result = None @@ -302,10 +286,11 @@ async def run( cx, self._config.fetch_size, self._result_closed, self._result_error ) + bookmarks = await self._get_all_bookmarks() await self._auto_result._run( query, parameters, self._config.database, self._config.impersonated_user, self._config.default_access_mode, - self._bookmarks, **kwargs + bookmarks, **kwargs ) return self._auto_result @@ -336,7 +321,8 @@ async def last_bookmark(self) -> t.Optional[str]: await self._auto_result.consume() if self._transaction and self._transaction._closed: - self._collect_bookmark(self._transaction._bookmark) + await self._update_bookmark(self._transaction._database, + self._transaction._bookmark) self._transaction = None if self._bookmarks: @@ -377,14 +363,16 @@ async def last_bookmarks(self) -> Bookmarks: await self._auto_result.consume() if self._transaction and self._transaction._closed(): - self._collect_bookmark(self._transaction._bookmark) + await self._update_bookmark(self._transaction._database, + self._transaction._bookmark) self._transaction = None return Bookmarks.from_raw_values(self._bookmarks) async def _transaction_closed_handler(self): if self._transaction: - self._collect_bookmark(self._transaction._bookmark) + await self._update_bookmark(self._transaction._database, + self._transaction._bookmark) self._transaction = None await self._disconnect() @@ -408,9 +396,10 @@ async def _open_transaction( self._transaction_error_handler, self._transaction_cancel_handler ) + bookmarks = await self._get_all_bookmarks() await self._transaction._begin( self._config.database, self._config.impersonated_user, - self._bookmarks, access_mode, metadata, timeout + bookmarks, access_mode, metadata, timeout ) async def begin_transaction( diff --git a/neo4j/_async/work/transaction.py b/neo4j/_async/work/transaction.py index 11a496ac9..eee74c86c 100644 --- a/neo4j/_async/work/transaction.py +++ b/neo4j/_async/work/transaction.py @@ -44,6 +44,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error, connection, self._error_handler ) self._bookmark = None + self._database = None self._results = [] self._closed_flag = False self._last_error = None @@ -69,6 +70,7 @@ async def _exit(self, exception_type, exception_value, traceback): async def _begin( self, database, imp_user, bookmarks, access_mode, metadata, timeout ): + self._database = database self._connection.begin( bookmarks=bookmarks, metadata=metadata, timeout=timeout, mode=access_mode, db=database, imp_user=imp_user @@ -168,6 +170,7 @@ async def _commit(self): await self._connection.send_all() await self._connection.fetch_all() self._bookmark = metadata.get("bookmark") + self._database = metadata.get("db", self._database) except asyncio.CancelledError: self._on_cancel() raise diff --git a/neo4j/_async/work/workspace.py b/neo4j/_async/work/workspace.py index cd6e17b46..45f9e5dd5 100644 --- a/neo4j/_async/work/workspace.py +++ b/neo4j/_async/work/workspace.py @@ -20,12 +20,14 @@ import asyncio +from ..._async_compat.util import AsyncUtil from ..._conf import WorkspaceConfig from ..._deadline import Deadline from ..._meta import ( deprecation_warn, unclosed_resource_warn, ) +from ...api import Bookmarks from ...exceptions import ( ServiceUnavailable, SessionError, @@ -44,7 +46,10 @@ def __init__(self, pool, config): self._connection_access_mode = None # Sessions are supposed to cache the database on which to operate. self._cached_database = False - self._bookmarks = None + self._bookmarks = () + self._initial_bookmarks = () + self._bookmark_manager = None + self._last_from_bookmark_manager = None # Workspace has been closed. self._closed = False @@ -77,6 +82,81 @@ def _set_cached_database(self, database): self._cached_database = True self._config.database = database + def _initialize_bookmarks(self, bookmarks): + if isinstance(bookmarks, Bookmarks): + prepared_bookmarks = tuple(bookmarks.raw_values) + elif hasattr(bookmarks, "__iter__"): + deprecation_warn( + "Passing an iterable as `bookmarks` to `Session` is " + "deprecated. Please use a `Bookmarks` instance.", + stack_level=5 + ) + prepared_bookmarks = tuple(bookmarks) + elif not bookmarks: + prepared_bookmarks = () + else: + raise TypeError("Bookmarks must be an instance of Bookmarks or an " + "iterable of raw bookmarks (deprecated).") + self._initial_bookmarks = self._bookmarks = prepared_bookmarks + + async def _get_bookmarks(self, database): + if self._bookmark_manager is None: + return self._bookmarks + + # For 4.3- support: the server will not send the resolved home + # database back. To avoid confusion between `None` as in "all + # database" and `None` as in "home database" we re-write the + # home database to `""`, which otherwise is an invalid database + # name. It will not work properly either way, as the home database + # can change (server config change or client side user change). + if database is None: + database = "" + self._last_from_bookmark_manager = tuple({ + *await AsyncUtil.callback( + self._bookmark_manager.get_bookmarks, database + ), + *self._initial_bookmarks + }) + return self._last_from_bookmark_manager + + async def _get_all_bookmarks(self): + if self._bookmark_manager is None: + return self._bookmarks + + self._last_from_bookmark_manager = tuple({ + *await AsyncUtil.callback( + self._bookmark_manager.get_all_bookmarks, + ), + *self._initial_bookmarks + }) + return self._last_from_bookmark_manager + + async def _update_bookmarks(self, database, new_bookmarks): + if not new_bookmarks: + return + self._initial_bookmarks = () + self._bookmarks = new_bookmarks + if self._bookmark_manager is None: + return + previous_bookmarks = self._last_from_bookmark_manager + # For 4.3- support: the server will not send the resolved home + # database back. To avoid confusion between `None` as in "all + # database" and `None` as in "home database" we re-write the home + # database to `""`, which otherwise is an invalid database name. + if database is None: + database = "" + await AsyncUtil.callback( + self._bookmark_manager.update_bookmarks, + database, previous_bookmarks, new_bookmarks + ) + + async def _update_bookmark(self, database, bookmark): + if not bookmark: + return + if not database: + database = self._config.database + await self._update_bookmarks(database, (bookmark,)) + async def _connect(self, access_mode, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout if self._connection: @@ -99,7 +179,7 @@ async def _connect(self, access_mode, **acquire_kwargs): await self._pool.update_routing_table( database=self._config.database, imp_user=self._config.impersonated_user, - bookmarks=self._bookmarks, + bookmarks=await self._get_bookmarks("system"), acquisition_timeout=acquisition_timeout, database_callback=self._set_cached_database ) @@ -107,7 +187,7 @@ async def _connect(self, access_mode, **acquire_kwargs): "access_mode": access_mode, "timeout": acquisition_timeout, "database": self._config.database, - "bookmarks": self._bookmarks, + "bookmarks": await self._get_bookmarks("system"), "liveness_check_timeout": None, } acquire_kwargs_.update(acquire_kwargs) diff --git a/neo4j/_async_compat/concurrency.py b/neo4j/_async_compat/concurrency.py index 3f149a668..6c6145438 100644 --- a/neo4j/_async_compat/concurrency.py +++ b/neo4j/_async_compat/concurrency.py @@ -16,10 +16,17 @@ # limitations under the License. +from __future__ import annotations + import asyncio import collections import re import threading +import typing as t + + +if t.TYPE_CHECKING: + import typing_extensions as te from neo4j._async_compat.shims import wait_for @@ -211,6 +218,12 @@ def release(self): def __exit__(self, t, v, tb): self.release() + async def __aenter__(self): + return self.__enter__() + + async def __aexit__(self, t, v, tb): + self.__exit__(t, v, tb) + class AsyncCooperativeRLock: """Reentrant lock placeholder for cooperative asyncio Python. @@ -423,6 +436,8 @@ def notify_all(self): self.notify(len(self._waiters)) -Condition = threading.Condition -CooperativeLock = Lock = threading.Lock -CooperativeRLock = RLock = threading.RLock +Condition: te.TypeAlias = threading.Condition +CooperativeLock: te.TypeAlias = threading.Lock +Lock: te.TypeAlias = threading.Lock +CooperativeRLock: te.TypeAlias = threading.RLock +RLock: te.TypeAlias = threading.RLock diff --git a/neo4j/_async_compat/util.py b/neo4j/_async_compat/util.py index aaecf02e3..53771edf2 100644 --- a/neo4j/_async_compat/util.py +++ b/neo4j/_async_compat/util.py @@ -16,13 +16,23 @@ # limitations under the License. +from __future__ import annotations + import asyncio import inspect +import typing as t from functools import wraps from .._meta import experimental +if t.TYPE_CHECKING: + import typing_extensions as te + + _T = t.TypeVar("_T") + _P = te.ParamSpec("_P") + + __all__ = [ "AsyncUtil", "Util", @@ -43,6 +53,23 @@ async def next(it): async def list(it): return [x async for x in it] + @staticmethod + @t.overload + async def callback(cb: None, *args: object, **kwargs: object) -> None: + ... + + @staticmethod + @t.overload + async def callback( + cb: t.Union[ + t.Callable[_P, t.Union[_T, t.Awaitable[_T]]], + t.Callable[_P, t.Awaitable[_T]], + t.Callable[_P, _T], + ], + *args: _P.args, **kwargs: _P.kwargs + ) -> _T: + ... + @staticmethod async def callback(cb, *args, **kwargs): if callable(cb): @@ -61,13 +88,24 @@ async def shielded_function(*args, **kwargs): return shielded_function - is_async_code = True + is_async_code: t.ClassVar = True class Util: - iter = iter - next = next - list = list + iter: t.ClassVar = iter + next: t.ClassVar = next + list: t.ClassVar = list + + @staticmethod + @t.overload + def callback(cb: None, *args: object, **kwargs: object) -> None: + ... + + @staticmethod + @t.overload + def callback(cb: t.Callable[_P, _T], + *args: _P.args, **kwargs: _P.kwargs) -> _T: + ... @staticmethod def callback(cb, *args, **kwargs): @@ -78,4 +116,4 @@ def callback(cb, *args, **kwargs): def shielded(coro_function): return coro_function - is_async_code = False + is_async_code: t.ClassVar = False diff --git a/neo4j/_conf.py b/neo4j/_conf.py index 79cd95d8d..434670853 100644 --- a/neo4j/_conf.py +++ b/neo4j/_conf.py @@ -16,11 +16,14 @@ # limitations under the License. +import warnings from abc import ABCMeta from collections.abc import Mapping from ._meta import ( deprecation_warn, + experimental_warn, + ExperimentalWarning, get_user_agent, ) from .api import ( @@ -132,28 +135,59 @@ def __init__(self, new, converter=None): self.converter = converter +class DeprecatedOption: + """Used for deprecated config options without alternative.""" + + def __init__(self, value): + self.value = value + + +class ExperimentalOption: + """Used for experimental config options.""" + + def __init__(self, value): + self.value = value + + class ConfigType(ABCMeta): def __new__(mcs, name, bases, attributes): fields = [] deprecated_aliases = {} deprecated_alternatives = {} + deprecated_options = {} + experimental_options = {} for base in bases: if type(base) is mcs: fields += base.keys() deprecated_aliases.update(base._deprecated_aliases()) deprecated_alternatives.update(base._deprecated_alternatives()) + deprecated_options.update(base._deprecated_options()) + experimental_options.update(base._experimental_options()) for k, v in attributes.items(): + if ( + k.startswith("_") + or callable(v) + or isinstance(v, (staticmethod, classmethod)) + ): + continue if isinstance(v, DeprecatedAlias): deprecated_aliases[k] = v.new - elif isinstance(v, DeprecatedAlternative): + continue + if isinstance(v, DeprecatedAlternative): deprecated_alternatives[k] = v.new, v.converter - elif not (k.startswith("_") - or callable(v) - or isinstance(v, (staticmethod, classmethod))): - fields.append(k) + continue + fields.append(k) + if isinstance(v, DeprecatedOption): + deprecated_options[k] = v.value + attributes[k] = v.value + continue + if isinstance(v, ExperimentalOption): + experimental_options[k] = v.value + attributes[k] = v.value + continue def keys(_): return set(fields) @@ -173,6 +207,12 @@ def _deprecated_aliases(_): def _deprecated_alternatives(_): return deprecated_alternatives + def _deprecated_options(_): + return deprecated_options + + def _experimental_options(_): + return experimental_options + attributes.setdefault("keys", classmethod(keys)) attributes.setdefault("_get_new", classmethod(_get_new)) @@ -182,6 +222,10 @@ def _deprecated_alternatives(_): classmethod(_deprecated_aliases)) attributes.setdefault("_deprecated_alternatives", classmethod(_deprecated_alternatives)) + attributes.setdefault("_deprecated_options", + classmethod(_deprecated_options)) + attributes.setdefault("_experimental_options", + classmethod(_experimental_options)) return super(ConfigType, mcs).__new__( mcs, name, bases, {k: v for k, v in attributes.items() @@ -227,6 +271,15 @@ def __update(self, data): def set_attr(k, v): if k in self.keys(): + if k in self._deprecated_options(): + deprecation_warn("The '{}' config key is " + "deprecated.".format(k)) + if k in self._experimental_options(): + experimental_warn( + "The '{}' config key is experimental. " + "It might be changed or removed any time even without " + "prior notice.".format(k) + ) setattr(self, k, v) elif k in self._deprecated_keys(): k0 = self._get_new(k) @@ -253,7 +306,13 @@ def set_attr(k, v): def __init__(self, *args, **kwargs): for arg in args: - self.__update(arg) + if isinstance(arg, Config): + with warnings.catch_warnings(): + for cat in (DeprecationWarning, ExperimentalWarning): + warnings.filterwarnings("ignore", category=cat) + self.__update(arg) + else: + self.__update(arg) self.__update(kwargs) def __repr__(self): @@ -416,6 +475,10 @@ class WorkspaceConfig(Config): impersonated_user = None # Note that you need appropriate permissions to do so. + #: Bookmark Manager + bookmark_manager = ExperimentalOption(None) + # Specify the bookmark manager to be used for sessions by default. + class SessionConfig(WorkspaceConfig): """ Session configuration. @@ -427,6 +490,9 @@ class SessionConfig(WorkspaceConfig): #: Default AccessMode default_access_mode = WRITE_ACCESS + #: Whether to ignore the bookmark manager configured at driver level + ignore_bookmark_manager = False + class TransactionConfig(Config): """ Transaction configuration. This is internal for now. diff --git a/neo4j/_meta.py b/neo4j/_meta.py index 382f40507..3e41e3467 100644 --- a/neo4j/_meta.py +++ b/neo4j/_meta.py @@ -17,10 +17,14 @@ import asyncio +import typing as t from functools import wraps from warnings import warn +_FuncT = t.TypeVar("_FuncT", bound=t.Callable) + + # Can be automatically overridden in builds package = "neo4j" version = "5.0.dev0" @@ -39,22 +43,19 @@ def get_user_agent(): return template.format(*fields) -def deprecation_warn(message, stack_level=1): - warn(message, category=DeprecationWarning, stacklevel=stack_level + 1) +def _id(x): + return x -from typing import ( - Callable, - cast, - TypeVar, -) +def copy_signature(_: _FuncT) -> t.Callable[[t.Callable], _FuncT]: + return _id -T = TypeVar("T") -FuncT = TypeVar("FuncT", bound=Callable[..., object]) +def deprecation_warn(message, stack_level=1): + warn(message, category=DeprecationWarning, stacklevel=stack_level + 1) -def deprecated(message: str) -> Callable[[FuncT], FuncT]: +def deprecated(message: str) -> t.Callable[[_FuncT], _FuncT]: """ Decorator for deprecating functions and methods. :: @@ -64,21 +65,21 @@ def foo(x): pass """ - def decorator(f: FuncT) -> FuncT: + def decorator(f): if asyncio.iscoroutinefunction(f): @wraps(f) async def inner(*args, **kwargs): deprecation_warn(message, stack_level=2) return await f(*args, **kwargs) - return cast(FuncT, inner) + return inner else: @wraps(f) def inner(*args, **kwargs): deprecation_warn(message, stack_level=2) return f(*args, **kwargs) - return cast(FuncT, inner) + return inner return decorator @@ -86,7 +87,7 @@ def inner(*args, **kwargs): def deprecated_property(message: str): def decorator(f): return property(deprecated(message)(f)) - return cast(property, decorator) + return t.cast(property, decorator) class ExperimentalWarning(Warning): @@ -98,7 +99,7 @@ def experimental_warn(message, stack_level=1): warn(message, category=ExperimentalWarning, stacklevel=stack_level + 1) -def experimental(message): +def experimental(message) -> t.Callable[[_FuncT], _FuncT]: """ Decorator for tagging experimental functions and methods. :: diff --git a/neo4j/_sync/bookmark_manager.py b/neo4j/_sync/bookmark_manager.py new file mode 100644 index 000000000..85cb13bfe --- /dev/null +++ b/neo4j/_sync/bookmark_manager.py @@ -0,0 +1,106 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import typing as t +from collections import defaultdict + +from .._async_compat.concurrency import CooperativeLock +from .._async_compat.util import Util +from ..api import ( + BookmarkManager, + Bookmarks, +) + + +T_BmSupplier = t.Callable[[t.Optional[str]], + t.Union[Bookmarks, t.Union[Bookmarks]]] +T_BmConsumer = t.Callable[[str, Bookmarks], t.Union[None, t.Union[None]]] + + +def _bookmarks_to_set( + bookmarks: t.Union[Bookmarks, t.Iterable[str]] +) -> t.Set[str]: + if isinstance(bookmarks, Bookmarks): + return set(bookmarks.raw_values) + return set(map(str, bookmarks)) + + +class Neo4jBookmarkManager(BookmarkManager): + def __init__( + self, + initial_bookmarks: t.Mapping[str, t.Union[Bookmarks, + t.Iterable[str]]] = None, + bookmarks_supplier: T_BmSupplier = None, + bookmarks_consumer: T_BmConsumer = None + ) -> None: + super().__init__() + self._bookmarks_supplier = bookmarks_supplier + self._bookmarks_consumer = bookmarks_consumer + if initial_bookmarks is None: + initial_bookmarks = {} + self._bookmarks = defaultdict( + set, ((k, _bookmarks_to_set(v)) + for k, v in initial_bookmarks.items()) + ) + self._lock = CooperativeLock() + + def update_bookmarks( + self, database: str, previous_bookmarks: t.Collection[str], + new_bookmarks: t.Collection[str] + ) -> None: + if not new_bookmarks: + return + with self._lock: + curr_bms = self._bookmarks[database] + curr_bms.difference_update(previous_bookmarks) + curr_bms.update(new_bookmarks) + if self._bookmarks_consumer: + curr_bms_snapshot = Bookmarks.from_raw_values(curr_bms) + if self._bookmarks_consumer: + Util.callback( + self._bookmarks_consumer, database, curr_bms_snapshot + ) + + def get_bookmarks(self, database: str) -> t.Set[str]: + with self._lock: + bms = set(self._bookmarks[database]) + if self._bookmarks_supplier: + extra_bms = Util.callback( + self._bookmarks_supplier, database + ) + bms.update(extra_bms.raw_values) + return bms + + def get_all_bookmarks(self) -> t.Set[str]: + bms: t.Set[str] = set() + with self._lock: + for database in self._bookmarks.keys(): + bms.update(self._bookmarks[database]) + if self._bookmarks_supplier: + extra_bms = Util.callback( + self._bookmarks_supplier, None + ) + bms.update(extra_bms.raw_values) + return bms + + def forget(self, databases: t.Iterable[str]) -> None: + with self._lock: + for database in databases: + self._bookmarks.pop(database, None) diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py index e6096b3d7..058928634 100644 --- a/neo4j/_sync/driver.py +++ b/neo4j/_sync/driver.py @@ -44,6 +44,7 @@ from ..addressing import Address from ..api import ( Auth, + BookmarkManager, Bookmarks, DRIVER_BOLT, DRIVER_NEO4J, @@ -62,6 +63,11 @@ URI_SCHEME_NEO4J_SECURE, URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, ) +from .bookmark_manager import ( + Neo4jBookmarkManager, + T_BmConsumer as _T_BmConsumer, + T_BmSupplier as _T_BmSupplier, +) from .work import Session @@ -103,7 +109,9 @@ def driver( retry_delay_jitter_factor: float = ..., database: t.Optional[str] = ..., fetch_size: int = ..., - impersonated_user: t.Optional[str] = ... + impersonated_user: t.Optional[str] = ..., + bookmark_manager: t.Union[BookmarkManager, + BookmarkManager, None] = ... ) -> Driver: ... @@ -202,6 +210,81 @@ def driver(cls, uri, *, auth=None, **config) -> Driver: return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config) + @classmethod + @experimental( + "The bookmark manager feature is experimental. " + "It might be changed or removed any time even without prior notice." + ) + def bookmark_manager( + cls, + initial_bookmarks: t.Mapping[str, t.Union[Bookmarks, + t.Iterable[str]]] = None, + bookmarks_supplier: _T_BmSupplier = None, + bookmarks_consumer: _T_BmConsumer = None + ) -> BookmarkManager: + """Create a :class:`.BookmarkManager` with default implementation. + + Basic usage example to configure sessions with the built-in bookmark + manager implementation so that all work is automatically causally + chained (i.e., all reads can observe all previous writes even in a + clustered setup):: + + import neo4j + + driver = neo4j.GraphDatabase.driver(...) + bookmark_manager = neo4j.BookmarkManager(...) + + with driver.session( + bookmark_manager=bookmark_manager + ) as session1: + with driver.session( + bookmark_manager=bookmark_manager + ) as session2: + session1.run("") + # READ_QUERY is guaranteed to see what WRITE_QUERY wrote. + session2.run("") + + This is a very contrived example, and in this particular case, having + both queries in the same session has the exact same effect and might + even be more performant. However, when dealing with sessions spanning + multiple threads, Tasks, processes, or even hosts, the bookmark + manager can come in handy as sessions are not safe to be used + concurrently. + + :param initial_bookmarks: + The initial set of bookmarks. The returned bookmark manager will + use this to initialize its internal bookmarks per database. + If present, this parameter must be a mapping of database names + to :class:`.Bookmarks` or an iterable of raw bookmark values (str). + :param bookmarks_supplier: + Function which will be called every time the default bookmark + manager's method :meth:`.BookmarkManager.get_bookmarks` + or :meth:`.BookmarkManager.get_all_bookmarks` gets called. + The function will be passed the name of the database (``str``) if + ``.get_bookmarks`` is called or ``None`` if ``.get_all_bookmarks`` + is called. The function must return a :class:`.Bookmarks` object. + The result of ``bookmarks_supplier`` will then be concatenated with + the internal set of bookmarks and used to configure the session in + creation. + :param bookmarks_consumer: + Function which will be called whenever the set of bookmarks + handled by the bookmark manager gets updated with the new + internal bookmark set. It will receive the name of the database + and the new set of bookmarks. + + :returns: A default implementation of :class:`BookmarkManager`. + + **This is experimental.** (See :ref:`filter-warnings-ref`) + It might be changed or removed any time even without prior notice. + + .. versionadded:: 5.0 + """ + return Neo4jBookmarkManager( + initial_bookmarks=initial_bookmarks, + bookmarks_supplier=bookmarks_supplier, + bookmarks_consumer=bookmarks_consumer + ) + @classmethod def bolt_driver(cls, target, *, auth=None, **config): """ Create a driver for direct Bolt server access that uses @@ -339,13 +422,16 @@ def session( fetch_size: int = ..., impersonated_user: t.Optional[str] = ..., bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + ignore_bookmark_manager: bool = ..., default_access_mode: str = ..., + bookmark_manager: t.Union[BookmarkManager, + BookmarkManager, None] = ..., # undocumented/unsupported options # they may be change or removed any time without prior notice initial_retry_delay: float = ..., retry_delay_multiplier: float = ..., - retry_delay_jitter_factor: float = ..., + retry_delay_jitter_factor: float = ... ) -> Session: ... @@ -382,11 +468,13 @@ def verify_connectivity( impersonated_user: t.Optional[str] = ..., bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., default_access_mode: str = ..., + bookmark_manager: t.Union[BookmarkManager, + BookmarkManager, None] = ..., # undocumented/unsupported options initial_retry_delay: float = ..., retry_delay_multiplier: float = ..., - retry_delay_jitter_factor: float = ..., + retry_delay_jitter_factor: float = ... ) -> None: ... @@ -444,11 +532,13 @@ def get_server_info( impersonated_user: t.Optional[str] = ..., bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., default_access_mode: str = ..., + bookmark_manager: t.Union[BookmarkManager, + BookmarkManager, None] = ..., # undocumented/unsupported options initial_retry_delay: float = ..., retry_delay_multiplier: float = ..., - retry_delay_jitter_factor: float = ..., + retry_delay_jitter_factor: float = ... ) -> ServerInfo: ... diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py index d3ea665f4..947555a01 100644 --- a/neo4j/_sync/io/_pool.py +++ b/neo4j/_sync/io/_pool.py @@ -515,9 +515,9 @@ def fetch_routing_info( cx = self._acquire(address, deadline, None) try: routing_table = cx.route( - database or self.workspace_config.database, - imp_user or self.workspace_config.impersonated_user, - bookmarks + database=database or self.workspace_config.database, + imp_user=imp_user or self.workspace_config.impersonated_user, + bookmarks=bookmarks ) finally: self.release(cx) diff --git a/neo4j/_sync/work/result.py b/neo4j/_sync/work/result.py index 66e6b770c..e86305caa 100644 --- a/neo4j/_sync/work/result.py +++ b/neo4j/_sync/work/result.py @@ -80,6 +80,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error): self._keys = None self._record_buffer = deque() self._summary = None + self._database = None self._bookmark = None self._raw_qid = -1 self._fetch_size = fetch_size @@ -127,7 +128,9 @@ def _run( "query": query_text, "parameters": parameters, "server": self._connection.server_info, + "database": db, } + self._database = db def on_attached(metadata): self._metadata.update(metadata) @@ -189,6 +192,7 @@ def on_success(summary_metadata): return self._metadata.update(summary_metadata) self._bookmark = summary_metadata.get("bookmark") + self._database = summary_metadata.get("db", self._database) self._connection.pull( n=self._fetch_size, @@ -220,6 +224,7 @@ def on_success(summary_metadata): self._discarding = False self._metadata.update(summary_metadata) self._bookmark = summary_metadata.get("bookmark") + self._database = summary_metadata.get("db", self._database) # This was the last page received, discard the rest self._connection.discard( diff --git a/neo4j/_sync/work/session.py b/neo4j/_sync/work/session.py index 743e2965d..7f4d29c5b 100644 --- a/neo4j/_sync/work/session.py +++ b/neo4j/_sync/work/session.py @@ -36,10 +36,7 @@ from ..._async_compat import sleep from ..._async_compat.util import Util from ..._conf import SessionConfig -from ..._meta import ( - deprecated, - deprecation_warn, -) +from ..._meta import deprecated from ...api import ( Bookmarks, READ_ACCESS, @@ -100,9 +97,11 @@ class Session(Workspace): _state_failed = False def __init__(self, pool, session_config): - super().__init__(pool, session_config) assert isinstance(session_config, SessionConfig) - self._bookmarks = self._prepare_bookmarks(session_config.bookmarks) + super().__init__(pool, session_config) + self._initialize_bookmarks(session_config.bookmarks) + if not session_config.ignore_bookmark_manager: + self._bookmark_manager = session_config.bookmark_manager def __enter__(self) -> Session: return self @@ -110,27 +109,12 @@ def __enter__(self) -> Session: def __exit__(self, exception_type, exception_value, traceback): if exception_type: if issubclass(exception_type, asyncio.CancelledError): - self._handle_cancellation(message="__aexit__") + self._handle_cancellation(message="__exit__") self._closed = True return self._state_failed = True self.close() - def _prepare_bookmarks(self, bookmarks): - if isinstance(bookmarks, Bookmarks): - return tuple(bookmarks.raw_values) - if hasattr(bookmarks, "__iter__"): - deprecation_warn( - "Passing an iterable as `bookmarks` to `Session` is " - "deprecated. Please use a `Bookmarks` instance.", - stack_level=5 - ) - return tuple(bookmarks) - if not bookmarks: - return () - raise TypeError("Bookmarks must be an instance of Bookmarks or an " - "iterable of raw bookmarks (deprecated).") - def _connect(self, access_mode, **access_kwargs): if access_mode is None: access_mode = self._config.default_access_mode @@ -147,10 +131,6 @@ def _disconnect(self, sync=False): self._handle_cancellation(message="_disconnect") raise - def _collect_bookmark(self, bookmark): - if bookmark: - self._bookmarks = bookmark, - def _handle_cancellation(self, message="General"): self._transaction = None self._auto_result = None @@ -165,7 +145,8 @@ def _handle_cancellation(self, message="General"): def _result_closed(self): if self._auto_result: - self._collect_bookmark(self._auto_result._bookmark) + self._update_bookmark(self._auto_result._database, + self._auto_result._bookmark) self._auto_result = None self._disconnect() @@ -196,7 +177,10 @@ def close(self) -> None: if self._state_failed is False: try: self._auto_result.consume() - self._collect_bookmark(self._auto_result._bookmark) + self._update_bookmark( + self._auto_result._database, + self._auto_result._bookmark + ) except Exception as error: # TODO: Investigate potential non graceful close states self._auto_result = None @@ -302,10 +286,11 @@ def run( cx, self._config.fetch_size, self._result_closed, self._result_error ) + bookmarks = self._get_all_bookmarks() self._auto_result._run( query, parameters, self._config.database, self._config.impersonated_user, self._config.default_access_mode, - self._bookmarks, **kwargs + bookmarks, **kwargs ) return self._auto_result @@ -336,7 +321,8 @@ def last_bookmark(self) -> t.Optional[str]: self._auto_result.consume() if self._transaction and self._transaction._closed: - self._collect_bookmark(self._transaction._bookmark) + self._update_bookmark(self._transaction._database, + self._transaction._bookmark) self._transaction = None if self._bookmarks: @@ -377,14 +363,16 @@ def last_bookmarks(self) -> Bookmarks: self._auto_result.consume() if self._transaction and self._transaction._closed(): - self._collect_bookmark(self._transaction._bookmark) + self._update_bookmark(self._transaction._database, + self._transaction._bookmark) self._transaction = None return Bookmarks.from_raw_values(self._bookmarks) def _transaction_closed_handler(self): if self._transaction: - self._collect_bookmark(self._transaction._bookmark) + self._update_bookmark(self._transaction._database, + self._transaction._bookmark) self._transaction = None self._disconnect() @@ -408,9 +396,10 @@ def _open_transaction( self._transaction_error_handler, self._transaction_cancel_handler ) + bookmarks = self._get_all_bookmarks() self._transaction._begin( self._config.database, self._config.impersonated_user, - self._bookmarks, access_mode, metadata, timeout + bookmarks, access_mode, metadata, timeout ) def begin_transaction( diff --git a/neo4j/_sync/work/transaction.py b/neo4j/_sync/work/transaction.py index 8a297dd07..41c1fd895 100644 --- a/neo4j/_sync/work/transaction.py +++ b/neo4j/_sync/work/transaction.py @@ -44,6 +44,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error, connection, self._error_handler ) self._bookmark = None + self._database = None self._results = [] self._closed_flag = False self._last_error = None @@ -69,6 +70,7 @@ def _exit(self, exception_type, exception_value, traceback): def _begin( self, database, imp_user, bookmarks, access_mode, metadata, timeout ): + self._database = database self._connection.begin( bookmarks=bookmarks, metadata=metadata, timeout=timeout, mode=access_mode, db=database, imp_user=imp_user @@ -168,6 +170,7 @@ def _commit(self): self._connection.send_all() self._connection.fetch_all() self._bookmark = metadata.get("bookmark") + self._database = metadata.get("db", self._database) except asyncio.CancelledError: self._on_cancel() raise diff --git a/neo4j/_sync/work/workspace.py b/neo4j/_sync/work/workspace.py index c780241a5..844bb96c7 100644 --- a/neo4j/_sync/work/workspace.py +++ b/neo4j/_sync/work/workspace.py @@ -20,12 +20,14 @@ import asyncio +from ..._async_compat.util import Util from ..._conf import WorkspaceConfig from ..._deadline import Deadline from ..._meta import ( deprecation_warn, unclosed_resource_warn, ) +from ...api import Bookmarks from ...exceptions import ( ServiceUnavailable, SessionError, @@ -44,7 +46,10 @@ def __init__(self, pool, config): self._connection_access_mode = None # Sessions are supposed to cache the database on which to operate. self._cached_database = False - self._bookmarks = None + self._bookmarks = () + self._initial_bookmarks = () + self._bookmark_manager = None + self._last_from_bookmark_manager = None # Workspace has been closed. self._closed = False @@ -77,6 +82,81 @@ def _set_cached_database(self, database): self._cached_database = True self._config.database = database + def _initialize_bookmarks(self, bookmarks): + if isinstance(bookmarks, Bookmarks): + prepared_bookmarks = tuple(bookmarks.raw_values) + elif hasattr(bookmarks, "__iter__"): + deprecation_warn( + "Passing an iterable as `bookmarks` to `Session` is " + "deprecated. Please use a `Bookmarks` instance.", + stack_level=5 + ) + prepared_bookmarks = tuple(bookmarks) + elif not bookmarks: + prepared_bookmarks = () + else: + raise TypeError("Bookmarks must be an instance of Bookmarks or an " + "iterable of raw bookmarks (deprecated).") + self._initial_bookmarks = self._bookmarks = prepared_bookmarks + + def _get_bookmarks(self, database): + if self._bookmark_manager is None: + return self._bookmarks + + # For 4.3- support: the server will not send the resolved home + # database back. To avoid confusion between `None` as in "all + # database" and `None` as in "home database" we re-write the + # home database to `""`, which otherwise is an invalid database + # name. It will not work properly either way, as the home database + # can change (server config change or client side user change). + if database is None: + database = "" + self._last_from_bookmark_manager = tuple({ + *Util.callback( + self._bookmark_manager.get_bookmarks, database + ), + *self._initial_bookmarks + }) + return self._last_from_bookmark_manager + + def _get_all_bookmarks(self): + if self._bookmark_manager is None: + return self._bookmarks + + self._last_from_bookmark_manager = tuple({ + *Util.callback( + self._bookmark_manager.get_all_bookmarks, + ), + *self._initial_bookmarks + }) + return self._last_from_bookmark_manager + + def _update_bookmarks(self, database, new_bookmarks): + if not new_bookmarks: + return + self._initial_bookmarks = () + self._bookmarks = new_bookmarks + if self._bookmark_manager is None: + return + previous_bookmarks = self._last_from_bookmark_manager + # For 4.3- support: the server will not send the resolved home + # database back. To avoid confusion between `None` as in "all + # database" and `None` as in "home database" we re-write the home + # database to `""`, which otherwise is an invalid database name. + if database is None: + database = "" + Util.callback( + self._bookmark_manager.update_bookmarks, + database, previous_bookmarks, new_bookmarks + ) + + def _update_bookmark(self, database, bookmark): + if not bookmark: + return + if not database: + database = self._config.database + self._update_bookmarks(database, (bookmark,)) + def _connect(self, access_mode, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout if self._connection: @@ -99,7 +179,7 @@ def _connect(self, access_mode, **acquire_kwargs): self._pool.update_routing_table( database=self._config.database, imp_user=self._config.impersonated_user, - bookmarks=self._bookmarks, + bookmarks=self._get_bookmarks("system"), acquisition_timeout=acquisition_timeout, database_callback=self._set_cached_database ) @@ -107,7 +187,7 @@ def _connect(self, access_mode, **acquire_kwargs): "access_mode": access_mode, "timeout": acquisition_timeout, "database": self._config.database, - "bookmarks": self._bookmarks, + "bookmarks": self._get_bookmarks("system"), "liveness_check_timeout": None, } acquire_kwargs_.update(acquire_kwargs) diff --git a/neo4j/api.py b/neo4j/api.py index 7db50c907..a95c7216e 100644 --- a/neo4j/api.py +++ b/neo4j/api.py @@ -20,6 +20,7 @@ from __future__ import annotations +import abc import typing as t from urllib.parse import ( parse_qs, @@ -35,6 +36,12 @@ from .exceptions import ConfigurationError +if t.TYPE_CHECKING: + from typing_extensions import Protocol as _Protocol +else: + _Protocol = object + + READ_ACCESS: te.Final[str] = "READ" WRITE_ACCESS: te.Final[str] = "WRITE" @@ -235,9 +242,11 @@ def __repr__(self) -> str: ) def __bool__(self) -> bool: + """True if there are bookmarks in the container.""" return bool(self._raw_values) def __add__(self, other: Bookmarks) -> Bookmarks: + """Add multiple containers together.""" if isinstance(other, Bookmarks): if not other: return self @@ -363,6 +372,121 @@ def from_bytes(cls, b: bytes) -> Version: return Version(b[-1], b[-2]) +class BookmarkManager(_Protocol, metaclass=abc.ABCMeta): + """Class to manage bookmarks throughout the driver's lifetime. + + Neo4j clusters are eventually consistent, meaning that there is no + guarantee a query will be able to read changes made by a previous query. + For cases where such a guarantee is necessary, the server provides + bookmarks to the client. A bookmark is an abstract token that represents + some state of the database. By passing one or multiple bookmarks along + with a query, the server will make sure that the query will not get + executed before the represented state(s) (or a later state) have been + established. + + The bookmark manager is an interface used by the driver for keeping + track of the bookmarks and this way keeping sessions automatically + consistent. Configure the driver to use a specific bookmark manager with + :ref:`bookmark-manager-ref`. + + This class is just an abstract base class that defines the required + interface. Create a child class to implement a specific bookmark manager + or make user of the default implementation provided by the driver through + :meth:`.GraphDatabase.bookmark_manager()`. + + .. note:: + All methods must be concurrency safe. + + Generally, all methods need to be able to cope with getting passed a + ``database`` parameter that is (until then) unknown to the manager. + + .. versionadded:: 5.0 + """ + + @abc.abstractmethod + def update_bookmarks( + self, database: str, previous_bookmarks: t.Collection[str], + new_bookmarks: t.Collection[str] + ) -> None: + """Handle bookmark updates. + + :param database: + The database which the bookmarks belong to + :param previous_bookmarks: + The bookmarks used at the start of a transaction + :param new_bookmarks: + The new bookmarks retrieved at the end of a transaction + """ + ... + + @abc.abstractmethod + def get_bookmarks(self, database: str) -> t.Collection[str]: + """Return the bookmarks for a given database. + + :param database: The database which the bookmarks belong to + + :returns: The bookmarks for the given database + """ + ... + + @abc.abstractmethod + def get_all_bookmarks(self) -> t.Collection[str]: + """Return all bookmarks for all known databases. + + :returns: The collected bookmarks. + """ + ... + + @abc.abstractmethod + def forget(self, databases: t.Iterable[str]) -> None: + """Forget the bookmarks for the given databases. + + This method is not called by the driver. + Forgetting unused databases is the user's responsibility. + + :param databases: + The databases which the bookmarks will be removed for. + """ + ... + + +class AsyncBookmarkManager(_Protocol, metaclass=abc.ABCMeta): + """Same as :class:`.BookmarkManager` but with async methods. + + The driver comes with a default implementation of the async bookmark + manager accessible through :attr:`.AsyncGraphDatabase.bookmark_manager()`. + + .. versionadded:: 5.0 + """ + + @abc.abstractmethod + async def update_bookmarks( + self, database: str, previous_bookmarks: t.Collection[str], + new_bookmarks: t.Collection[str] + ) -> None: + ... + + update_bookmarks.__doc__ = BookmarkManager.update_bookmarks.__doc__ + + @abc.abstractmethod + async def get_bookmarks(self, database: str) -> t.Collection[str]: + ... + + get_bookmarks.__doc__ = BookmarkManager.get_bookmarks.__doc__ + + @abc.abstractmethod + async def get_all_bookmarks(self) -> t.Collection[str]: + ... + + get_all_bookmarks.__doc__ = BookmarkManager.get_all_bookmarks.__doc__ + + @abc.abstractmethod + async def forget(self, databases: t.Iterable[str]) -> None: + ... + + forget.__doc__ = BookmarkManager.forget.__doc__ + + def parse_neo4j_uri(uri): parsed = urlparse(uri) diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index c6aa85e6a..851865c6b 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -55,6 +55,9 @@ def __init__(self, rd, wr): self.drivers = {} self.custom_resolutions = {} self.dns_resolutions = {} + self.bookmark_managers = {} + self.bookmarks_consumptions = {} + self.bookmarks_supplies = {} self.sessions = {} self.results = {} self.errors = {} diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 8f82758f0..6f939e9cb 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -22,6 +22,7 @@ from os import path import neo4j +import neo4j.api from neo4j._async_compat.util import AsyncUtil from .. import ( @@ -146,7 +147,6 @@ async def NewDriver(backend, data): kwargs["trusted_certificates"] = neo4j.TrustCustomCAs(*cert_paths) data.mark_item_as_read_if_equals("livenessCheckTimeoutMs", None) - data.mark_item_as_read("domainNameResolverRegistered") driver = neo4j.AsyncGraphDatabase.driver( data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs ) @@ -243,6 +243,84 @@ async def DomainNameResolutionCompleted(backend, data): backend.dns_resolutions[data["requestId"]] = data["addresses"] +async def NewBookmarkManager(backend, data): + bookmark_manager_id = backend.next_key() + + bmm_kwargs = {} + data.mark_item_as_read("initialBookmarks", recursive=True) + bmm_kwargs["initial_bookmarks"] = data.get("initialBookmarks") + if data.get("bookmarksSupplierRegistered"): + bmm_kwargs["bookmarks_supplier"] = bookmarks_supplier( + backend, bookmark_manager_id + ) + if data.get("bookmarksConsumerRegistered"): + bmm_kwargs["bookmarks_consumer"] = bookmarks_consumer( + backend, bookmark_manager_id + ) + + bookmark_manager = neo4j.AsyncGraphDatabase.bookmark_manager(**bmm_kwargs) + backend.bookmark_managers[bookmark_manager_id] = bookmark_manager + await backend.send_response("BookmarkManager", {"id": bookmark_manager_id}) + + +async def BookmarkManagerClose(backend, data): + bookmark_manager_id = data["id"] + del backend.bookmark_managers[bookmark_manager_id] + await backend.send_response("BookmarkManager", {"id": bookmark_manager_id}) + + +def bookmarks_supplier(backend, bookmark_manager_id): + async def supplier(database): + key = backend.next_key() + await backend.send_response("BookmarksSupplierRequest", { + "id": key, + "bookmarkManagerId": bookmark_manager_id, + "database": database + }) + if not await backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.bookmarks_supplies: + raise RuntimeError( + "Backend did not receive expected " + "BookmarksSupplierCompleted message for id %s" % key + ) + return backend.bookmarks_supplies.pop(key) + + return supplier + + +async def BookmarksSupplierCompleted(backend, data): + backend.bookmarks_supplies[data["requestId"]] = \ + neo4j.Bookmarks.from_raw_values(data["bookmarks"]) + + +def bookmarks_consumer(backend, bookmark_manager_id): + async def consumer(database, bookmarks): + key = backend.next_key() + await backend.send_response("BookmarksConsumerRequest", { + "id": key, + "bookmarkManagerId": bookmark_manager_id, + "database": database, + "bookmarks": list(bookmarks.raw_values) + }) + if not await backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.bookmarks_consumptions: + raise RuntimeError( + "Backend did not receive expected " + "BookmarksConsumerCompleted message for id %s" % key + ) + del backend.bookmarks_consumptions[key] + + return consumer + + +async def BookmarksConsumerCompleted(backend, data): + backend.bookmarks_consumptions[data["requestId"]] = True + + async def DriverClose(backend, data): key = data["driverId"] driver = backend.drivers[key] @@ -277,17 +355,25 @@ async def NewSession(backend, data): access_mode = neo4j.WRITE_ACCESS else: raise ValueError("Unknown access mode:" + access_mode) - bookmarks = None - if "bookmarks" in data and data["bookmarks"]: - bookmarks = neo4j.Bookmarks.from_raw_values(data["bookmarks"]) config = { - "default_access_mode": access_mode, - "bookmarks": bookmarks, - "database": data["database"], - "fetch_size": data.get("fetchSize", None), - "impersonated_user": data.get("impersonatedUser", None), - + "default_access_mode": access_mode, + "database": data["database"], } + if data.get("bookmarks") is not None: + config["bookmarks"] = neo4j.Bookmarks.from_raw_values( + data["bookmarks"] + ) + if data.get("bookmarkManagerId") is not None: + config["bookmark_manager"] = backend.bookmark_managers[ + data["bookmarkManagerId"] + ] + for (conf_name, data_name) in ( + ("fetch_size", "fetchSize"), + ("impersonated_user", "impersonatedUser"), + ("ignore_bookmark_manager", "ignoreBookmarkManager"), + ): + if data_name in data: + config[conf_name] = data[data_name] session = driver.session(**config) key = backend.next_key() backend.sessions[key] = SessionTracker(session) diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index a07a89d50..5dae1753d 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -55,6 +55,9 @@ def __init__(self, rd, wr): self.drivers = {} self.custom_resolutions = {} self.dns_resolutions = {} + self.bookmark_managers = {} + self.bookmarks_consumptions = {} + self.bookmarks_supplies = {} self.sessions = {} self.results = {} self.errors = {} diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 9274479fe..c46b65db9 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -22,6 +22,7 @@ from os import path import neo4j +import neo4j.api from neo4j._async_compat.util import Util from .. import ( @@ -146,7 +147,6 @@ def NewDriver(backend, data): kwargs["trusted_certificates"] = neo4j.TrustCustomCAs(*cert_paths) data.mark_item_as_read_if_equals("livenessCheckTimeoutMs", None) - data.mark_item_as_read("domainNameResolverRegistered") driver = neo4j.GraphDatabase.driver( data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs ) @@ -243,6 +243,84 @@ def DomainNameResolutionCompleted(backend, data): backend.dns_resolutions[data["requestId"]] = data["addresses"] +def NewBookmarkManager(backend, data): + bookmark_manager_id = backend.next_key() + + bmm_kwargs = {} + data.mark_item_as_read("initialBookmarks", recursive=True) + bmm_kwargs["initial_bookmarks"] = data.get("initialBookmarks") + if data.get("bookmarksSupplierRegistered"): + bmm_kwargs["bookmarks_supplier"] = bookmarks_supplier( + backend, bookmark_manager_id + ) + if data.get("bookmarksConsumerRegistered"): + bmm_kwargs["bookmarks_consumer"] = bookmarks_consumer( + backend, bookmark_manager_id + ) + + bookmark_manager = neo4j.GraphDatabase.bookmark_manager(**bmm_kwargs) + backend.bookmark_managers[bookmark_manager_id] = bookmark_manager + backend.send_response("BookmarkManager", {"id": bookmark_manager_id}) + + +def BookmarkManagerClose(backend, data): + bookmark_manager_id = data["id"] + del backend.bookmark_managers[bookmark_manager_id] + backend.send_response("BookmarkManager", {"id": bookmark_manager_id}) + + +def bookmarks_supplier(backend, bookmark_manager_id): + def supplier(database): + key = backend.next_key() + backend.send_response("BookmarksSupplierRequest", { + "id": key, + "bookmarkManagerId": bookmark_manager_id, + "database": database + }) + if not backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.bookmarks_supplies: + raise RuntimeError( + "Backend did not receive expected " + "BookmarksSupplierCompleted message for id %s" % key + ) + return backend.bookmarks_supplies.pop(key) + + return supplier + + +def BookmarksSupplierCompleted(backend, data): + backend.bookmarks_supplies[data["requestId"]] = \ + neo4j.Bookmarks.from_raw_values(data["bookmarks"]) + + +def bookmarks_consumer(backend, bookmark_manager_id): + def consumer(database, bookmarks): + key = backend.next_key() + backend.send_response("BookmarksConsumerRequest", { + "id": key, + "bookmarkManagerId": bookmark_manager_id, + "database": database, + "bookmarks": list(bookmarks.raw_values) + }) + if not backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.bookmarks_consumptions: + raise RuntimeError( + "Backend did not receive expected " + "BookmarksConsumerCompleted message for id %s" % key + ) + del backend.bookmarks_consumptions[key] + + return consumer + + +def BookmarksConsumerCompleted(backend, data): + backend.bookmarks_consumptions[data["requestId"]] = True + + def DriverClose(backend, data): key = data["driverId"] driver = backend.drivers[key] @@ -277,17 +355,25 @@ def NewSession(backend, data): access_mode = neo4j.WRITE_ACCESS else: raise ValueError("Unknown access mode:" + access_mode) - bookmarks = None - if "bookmarks" in data and data["bookmarks"]: - bookmarks = neo4j.Bookmarks.from_raw_values(data["bookmarks"]) config = { - "default_access_mode": access_mode, - "bookmarks": bookmarks, - "database": data["database"], - "fetch_size": data.get("fetchSize", None), - "impersonated_user": data.get("impersonatedUser", None), - + "default_access_mode": access_mode, + "database": data["database"], } + if data.get("bookmarks") is not None: + config["bookmarks"] = neo4j.Bookmarks.from_raw_values( + data["bookmarks"] + ) + if data.get("bookmarkManagerId") is not None: + config["bookmark_manager"] = backend.bookmark_managers[ + data["bookmarkManagerId"] + ] + for (conf_name, data_name) in ( + ("fetch_size", "fetchSize"), + ("impersonated_user", "impersonatedUser"), + ("ignore_bookmark_manager", "ignoreBookmarkManager"), + ): + if data_name in data: + config[conf_name] = data[data_name] session = driver.session(**config) key = backend.next_key() backend.sessions[key] = SessionTracker(session) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 8f72b4879..18710693f 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -16,6 +16,7 @@ "test_subtest_skips.tz_id" }, "features": { + "Feature:API:BookmarkManager": true, "Feature:API:ConnectionAcquisitionTimeout": true, "Feature:API:Driver:GetServerInfo": true, "Feature:API:Driver.IsEncrypted": true, @@ -48,6 +49,7 @@ "Optimization:ConnectionReuse": true, "Optimization:EagerTransactionBegin": true, "Optimization:ImplicitDefaultArguments": true, + "Optimization:MinimalBookmarksSet": true, "Optimization:MinimalResets": true, "Optimization:PullPipelining": true, "Optimization:ResultListFetchAll": "The idiomatic way to cast to list is indistinguishable from iterating over the result.", diff --git a/tests/_async_compat/__init__.py b/tests/_async_compat/__init__.py index b9aec02c4..12a189840 100644 --- a/tests/_async_compat/__init__.py +++ b/tests/_async_compat/__init__.py @@ -17,12 +17,16 @@ from .mark_decorator import ( + AsyncTestDecorators, mark_async_test, mark_sync_test, + TestDecorators, ) __all__ = [ + "AsyncTestDecorators", "mark_async_test", "mark_sync_test", + "TestDecorators", ] diff --git a/tests/_async_compat/mark_decorator.py b/tests/_async_compat/mark_decorator.py index cfb42f453..b89754cb2 100644 --- a/tests/_async_compat/mark_decorator.py +++ b/tests/_async_compat/mark_decorator.py @@ -24,3 +24,14 @@ def mark_sync_test(f): return f + + +class AsyncTestDecorators: + mark_async_only_test = mark_async_test + + +class TestDecorators: + @staticmethod + def mark_async_only_test(f): + skip_decorator = pytest.mark.skip("Async only test") + return skip_decorator(f) diff --git a/tests/unit/async_/conftest.py b/tests/unit/async_/conftest.py new file mode 100644 index 000000000..9d171987d --- /dev/null +++ b/tests/unit/async_/conftest.py @@ -0,0 +1,19 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .fixtures import * # necessary for pytest to discover the fixtures diff --git a/tests/unit/async_/fixtures/__init__.py b/tests/unit/async_/fixtures/__init__.py new file mode 100644 index 000000000..c3e907d44 --- /dev/null +++ b/tests/unit/async_/fixtures/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .fake_connection import * +from .fake_pool import * diff --git a/tests/unit/async_/work/_fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py similarity index 92% rename from tests/unit/async_/work/_fake_connection.py rename to tests/unit/async_/fixtures/fake_connection.py index 0966b1556..f0d0070c1 100644 --- a/tests/unit/async_/work/_fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -25,6 +25,14 @@ from neo4j._deadline import Deadline +__all__ = [ + "async_fake_connection_generator", + "async_fake_connection", + "async_scripted_connection_generator", + "async_scripted_connection", +] + + @pytest.fixture def async_fake_connection_generator(session_mocker): mock = session_mocker.mock_module @@ -158,15 +166,15 @@ def __getattr__(self, name): parent = super() def build_message_handler(name): - try: - expected_message, scripted_callbacks = \ - self._script[self._script_pos] - except IndexError: - pytest.fail("End of scripted connection reached.") - assert name == expected_message - self._script_pos += 1 - def func(*args, **kwargs): + try: + expected_message, scripted_callbacks = \ + self._script[self._script_pos] + except IndexError: + pytest.fail("End of scripted connection reached.") + assert name == expected_message + self._script_pos += 1 + async def callback(): for cb_name, default_cb_args in ( ("on_ignored", ({},)), @@ -176,8 +184,10 @@ async def callback(): ("on_summary", ()), ): cb = kwargs.get(cb_name, None) - if (not callable(cb) - or cb_name not in scripted_callbacks): + if ( + not callable(cb) + or cb_name not in scripted_callbacks + ): continue cb_args = scripted_callbacks[cb_name] if cb_args is None: diff --git a/tests/unit/async_/fixtures/fake_pool.py b/tests/unit/async_/fixtures/fake_pool.py new file mode 100644 index 000000000..ba4636f7e --- /dev/null +++ b/tests/unit/async_/fixtures/fake_pool.py @@ -0,0 +1,45 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._async.io._pool import AsyncIOPool + + +__all__ = [ + "fake_pool", +] + + +@pytest.fixture +def fake_pool(async_fake_connection_generator, mocker): + pool = mocker.AsyncMock(spec=AsyncIOPool) + assert not hasattr(pool, "acquired_connection_mocks") + pool.buffered_connection_mocks = [] + pool.acquired_connection_mocks = [] + + def acquire_side_effect(*_, **__): + if pool.buffered_connection_mocks: + connection = pool.buffered_connection_mocks.pop() + else: + connection = async_fake_connection_generator() + pool.acquired_connection_mocks.append(connection) + return connection + + pool.acquire.side_effect = acquire_side_effect + return pool diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index 7538e127b..eb74241ca 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -230,7 +230,7 @@ async def test_hint_recv_timeout_seconds( sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x3.PACKER_CLS, unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) - sockets.client.settimeout = mocker.AsyncMock() + sockets.client.settimeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.0", "hints": hints} ) diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index 285aa9744..c88b4af12 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -244,7 +244,7 @@ async def test_hint_recv_timeout_seconds( sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x4.PACKER_CLS, unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) - sockets.client.settimeout = mocker.MagicMock() + sockets.client.settimeout = mocker.Mock() await sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index 07b136551..b9bbc4e42 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -38,7 +38,6 @@ ) from ...._async_compat import mark_async_test -from ..work import async_fake_connection_generator # needed as fixture ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") diff --git a/tests/unit/async_/test_addressing.py b/tests/unit/async_/test_addressing.py index 0c3fcfc54..75a036f9c 100644 --- a/tests/unit/async_/test_addressing.py +++ b/tests/unit/async_/test_addressing.py @@ -34,7 +34,7 @@ @mark_async_test -async def test_address_resolve(): +async def test_address_resolve() -> None: address = Address(("127.0.0.1", 7687)) resolved = AsyncNetworkUtil.resolve_address(address) resolved = await AsyncUtil.list(resolved) @@ -45,7 +45,7 @@ async def test_address_resolve(): @mark_async_test -async def test_address_resolve_with_custom_resolver_none(): +async def test_address_resolve_with_custom_resolver_none() -> None: address = Address(("127.0.0.1", 7687)) resolved = AsyncNetworkUtil.resolve_address(address, resolver=None) resolved = await AsyncUtil.list(resolved) @@ -64,7 +64,9 @@ async def test_address_resolve_with_custom_resolver_none(): ) @mark_async_test -async def test_address_resolve_with_unresolvable_address(test_input, expected): +async def test_address_resolve_with_unresolvable_address( + test_input, expected +) -> None: with pytest.raises(expected): await AsyncUtil.list( AsyncNetworkUtil.resolve_address(test_input, resolver=None) @@ -73,7 +75,7 @@ async def test_address_resolve_with_unresolvable_address(test_input, expected): @mark_async_test @pytest.mark.parametrize("resolver_type", ("sync", "async")) -async def test_address_resolve_with_custom_resolver(resolver_type): +async def test_address_resolve_with_custom_resolver(resolver_type) -> None: def custom_resolver_sync(_): return [("127.0.0.1", 7687), ("localhost", 1234)] @@ -98,9 +100,11 @@ async def custom_resolver_async(_): @mark_async_test -async def test_address_unresolve(): +async def test_address_unresolve() -> None: + def custom_resolver(_): + return custom_resolved + custom_resolved = [("127.0.0.1", 7687), ("localhost", 4321)] - custom_resolver = lambda _: custom_resolved address = Address(("foobar", 1234)) unresolved = address.unresolved @@ -109,9 +113,9 @@ async def test_address_unresolve(): resolved = AsyncNetworkUtil.resolve_address( address, family=AF_INET, resolver=custom_resolver ) - resolved = await AsyncUtil.list(resolved) - custom_resolved = sorted(Address(a) for a in custom_resolved) - unresolved = sorted(a.unresolved for a in resolved) - assert custom_resolved == unresolved - assert (list(map(lambda a: a.__class__, custom_resolved)) - == list(map(lambda a: a.__class__, unresolved))) + resolved_list = await AsyncUtil.list(resolved) + custom_resolved_addresses = sorted(Address(a) for a in custom_resolved) + unresolved_list = sorted(a.unresolved for a in resolved_list) + assert custom_resolved_addresses == unresolved_list + assert (list(map(lambda a: a.__class__, custom_resolved_addresses)) + == list(map(lambda a: a.__class__, unresolved_list))) diff --git a/tests/unit/async_/test_bookmark_manager.py b/tests/unit/async_/test_bookmark_manager.py new file mode 100644 index 000000000..9532f670c --- /dev/null +++ b/tests/unit/async_/test_bookmark_manager.py @@ -0,0 +1,284 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import itertools +import typing as t + +import pytest + +import neo4j +from neo4j._async.bookmark_manager import AsyncNeo4jBookmarkManager +from neo4j._async_compat.util import AsyncUtil +from neo4j._meta import copy_signature +from neo4j.api import Bookmarks + +from ..._async_compat import mark_async_test + + +supplier_async_options = (True, False) if AsyncUtil.is_async_code else (False,) +consumer_async_options = supplier_async_options + + +@copy_signature(neo4j.AsyncGraphDatabase.bookmark_manager) +def bookmark_manager(*args, **kwargs): + with pytest.warns(neo4j.ExperimentalWarning, match="bookmark manager"): + return neo4j.AsyncGraphDatabase.bookmark_manager(*args, **kwargs) + + +@pytest.mark.parametrize("db", ("foobar", "system")) +@mark_async_test +async def test_return_empty_if_db_doesnt_exists(db) -> None: + bmm = bookmark_manager() + + assert set(await bmm.get_bookmarks(db)) == set() + + +@pytest.mark.parametrize("db", ("db1", "db2", "db3")) +@mark_async_test +async def test_return_initial_bookmarks_for_the_given_db(db) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"] + } + bmm = bookmark_manager(initial_bookmarks=initial_bookmarks) + + assert set(await bmm.get_bookmarks(db)) == set(initial_bookmarks[db]) + + +@pytest.mark.parametrize("db", ("db1", "db2", "db3")) +@pytest.mark.parametrize("supplier_async", supplier_async_options) +@mark_async_test +async def test_return_get_bookmarks_from_bookmarks_supplier( + db, mocker, supplier_async +) -> None: + extra_bookmarks = ["foo:bm1", "bar:bm2", "foo:bm1"] + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"] + } + mock_cls = mocker.AsyncMock if supplier_async else mocker.Mock + supplier = mock_cls( + return_value=Bookmarks.from_raw_values(extra_bookmarks) + ) + bmm = bookmark_manager(initial_bookmarks=initial_bookmarks, + bookmarks_supplier=supplier) + + assert set(await bmm.get_bookmarks(db)) == { + *extra_bookmarks, *initial_bookmarks.get(db, []) + } + if supplier_async: + supplier.assert_awaited_once_with(db) + else: + supplier.assert_called_once_with(db) + + +@pytest.mark.parametrize("with_initial_bookmarks", (True, False)) +@mark_async_test +async def test_return_all_bookmarks(with_initial_bookmarks) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"], + "db5": ["db3:bm1"] + } + bmm = bookmark_manager( + initial_bookmarks=initial_bookmarks if with_initial_bookmarks else None + ) + + all_bookmarks = await bmm.get_all_bookmarks() + + if with_initial_bookmarks: + assert all_bookmarks == set( + itertools.chain.from_iterable(initial_bookmarks.values()) + ) + else: + assert all_bookmarks == set() + + +@pytest.mark.parametrize("with_initial_bookmarks", (True, False)) +@pytest.mark.parametrize("supplier_async", supplier_async_options) +@mark_async_test +async def test_return_enriched_bookmarks_list_with_supplied_bookmarks( + with_initial_bookmarks, supplier_async, mocker +) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"], + } + extra_bookmarks = ["foo:bm1", "bar:bm2", "db3:bm1", "foo:bm1"] + mock_cls = mocker.AsyncMock if supplier_async else mocker.Mock + supplier = mock_cls( + return_value=Bookmarks.from_raw_values(extra_bookmarks) + ) + bmm = bookmark_manager( + initial_bookmarks=(initial_bookmarks + if with_initial_bookmarks else None), + bookmarks_supplier=supplier + ) + + all_bookmarks = await bmm.get_all_bookmarks() + + if with_initial_bookmarks: + assert all_bookmarks == set( + itertools.chain(*initial_bookmarks.values(), extra_bookmarks) + ) + else: + assert all_bookmarks == set(extra_bookmarks) + if supplier_async: + supplier.assert_awaited_once_with(None) + else: + supplier.assert_called_once_with(None) + + +@mark_async_test +async def test_chains_bookmarks_for_existing_db() -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"], + } + bmm = bookmark_manager(initial_bookmarks=initial_bookmarks) + await bmm.update_bookmarks("db3", ["db3:bm1"], ["db3:bm3"]) + new_bookmarks = await bmm.get_bookmarks("db3") + all_bookmarks = await bmm.get_all_bookmarks() + + assert new_bookmarks == {"db3:bm2", "db3:bm3"} + assert all_bookmarks == set( + itertools.chain.from_iterable(initial_bookmarks.values()) + ) - {"db3:bm1"} | {"db3:bm2", "db3:bm3"} + + +@mark_async_test +async def test_add_bookmarks_for_a_non_existing_database() -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"], + } + bmm = bookmark_manager(initial_bookmarks=initial_bookmarks) + await bmm.update_bookmarks( + "db5", ["db3:bm1", "db5:bm1"], ["db3:bm3", "db3:bm5"] + ) + new_bookmarks = await bmm.get_bookmarks("db5") + all_bookmarks = await bmm.get_all_bookmarks() + + assert new_bookmarks == {"db3:bm3", "db3:bm5"} + assert all_bookmarks == set( + itertools.chain.from_iterable(initial_bookmarks.values()) + ) | {"db3:bm3", "db3:bm5"} + + +@pytest.mark.parametrize("with_initial_bookmarks", (True, False)) +@pytest.mark.parametrize("consumer_async", consumer_async_options) +@pytest.mark.parametrize("db", ("db1", "db2", "db3")) +@mark_async_test +async def test_notify_on_new_bookmarks( + with_initial_bookmarks, consumer_async, db, mocker +) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1", "db1:bm2"], + "db2": ["db2:bm1"], + } + mock_cls = mocker.AsyncMock if consumer_async else mocker.Mock + consumer = mock_cls() + bmm = bookmark_manager( + initial_bookmarks=(initial_bookmarks + if with_initial_bookmarks else None), + bookmarks_consumer=consumer + ) + bookmarks_old = {"db1:bm1", "db3:bm1"} + bookmarks_new = {"db1:bm4"} + await bmm.update_bookmarks(db, bookmarks_old, bookmarks_new) + + if consumer_async: + consumer.assert_awaited_once() + args = consumer.await_args.args + else: + consumer.assert_called_once() + args = consumer.call_args.args + assert args[0] == db + assert isinstance(args[1], Bookmarks) + if with_initial_bookmarks: + expected_bms = ( + set(initial_bookmarks.get(db, [])) - bookmarks_old | bookmarks_new + ) + else: + expected_bms = bookmarks_new + assert args[1].raw_values == expected_bms + + +@pytest.mark.parametrize("consumer_async", consumer_async_options) +@pytest.mark.parametrize("with_initial_bookmarks", (True, False)) +@pytest.mark.parametrize("db", ("db1", "db2")) +@mark_async_test +async def test_does_not_notify_on_empty_new_bookmark_set( + with_initial_bookmarks, consumer_async, db, mocker +) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm2"] + } + mock_cls = mocker.AsyncMock if consumer_async else mocker.Mock + consumer = mock_cls() + bmm = bookmark_manager( + initial_bookmarks=(initial_bookmarks + if with_initial_bookmarks else None), + bookmarks_consumer=consumer + ) + await bmm.update_bookmarks(db, ["db1:bm1"], []) + + consumer.assert_not_called() + + +@pytest.mark.parametrize("dbs", ( + ["db1"], ["db2"], ["db1", "db2"], ["db1", "db3"], ["db1", "db2", "db3"] +)) +@mark_async_test +async def test_forget_database(dbs) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1", "db1:bm2"], + "db2": ["db2:bm1"], + } + bmm = bookmark_manager(initial_bookmarks=initial_bookmarks) + + for db in dbs: + assert (await bmm.get_bookmarks(db) + == set(initial_bookmarks.get(db, []))) + + await bmm.forget(dbs) + + # assert the key has been removed (memory optimization) + assert isinstance(bmm, AsyncNeo4jBookmarkManager) + assert (set(bmm._bookmarks.keys()) + == set(initial_bookmarks.keys()) - set(dbs)) + + for db in dbs: + assert await bmm.get_bookmarks(db) == set() + assert await bmm.get_all_bookmarks() == set( + bm for k, v in initial_bookmarks.items() if k not in dbs for bm in v + ) diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index 488c61dd3..2a5acf7b9 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -16,8 +16,10 @@ # limitations under the License. +from __future__ import annotations + import ssl -from functools import wraps +import typing as t import pytest @@ -32,19 +34,18 @@ TrustCustomCAs, TrustSystemCAs, ) -from neo4j._async_compat.util import AsyncUtil from neo4j.api import ( + AsyncBookmarkManager, + BookmarkManager, READ_ACCESS, WRITE_ACCESS, ) from neo4j.exceptions import ConfigurationError -from ..._async_compat import mark_async_test - - -@wraps(AsyncGraphDatabase.driver) -def create_driver(*args, **kwargs): - return AsyncGraphDatabase.driver(*args, **kwargs) +from ..._async_compat import ( + AsyncTestDecorators, + mark_async_test, +) @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @@ -60,7 +61,7 @@ async def test_direct_driver_constructor(protocol, host, port, params, auth_toke with pytest.warns(DeprecationWarning, match="routing context"): driver = AsyncGraphDatabase.driver(uri, auth=auth_token) else: - driver = create_driver(uri, auth=auth_token) + driver = AsyncGraphDatabase.driver(uri, auth=auth_token) assert isinstance(driver, AsyncBoltDriver) await driver.close() @@ -75,7 +76,7 @@ async def test_direct_driver_constructor(protocol, host, port, params, auth_toke @mark_async_test async def test_routing_driver_constructor(protocol, host, port, params, auth_token): uri = protocol + host + port + params - driver = create_driver(uri, auth=auth_token) + driver = AsyncGraphDatabase.driver(uri, auth=auth_token) assert isinstance(driver, AsyncNeo4jDriver) await driver.close() @@ -140,7 +141,7 @@ def driver_builder(): with pytest.warns(DeprecationWarning, match="trust"): return AsyncGraphDatabase.driver(test_uri, **test_config) else: - return create_driver(test_uri, **test_config) + return AsyncGraphDatabase.driver(test_uri, **test_config) if "+" in test_uri: # `+s` and `+ssc` are short hand syntax for not having to configure the @@ -159,7 +160,7 @@ def driver_builder(): )) def test_invalid_protocol(test_uri): with pytest.raises(ConfigurationError, match="scheme"): - create_driver(test_uri) + AsyncGraphDatabase.driver(test_uri) @pytest.mark.parametrize( @@ -174,7 +175,7 @@ def test_driver_trust_config_error( test_config, expected_failure, expected_failure_message ): with pytest.raises(expected_failure, match=expected_failure_message): - create_driver("bolt://127.0.0.1:9001", **test_config) + AsyncGraphDatabase.driver("bolt://127.0.0.1:9001", **test_config) @pytest.mark.parametrize("uri", ( @@ -182,28 +183,24 @@ def test_driver_trust_config_error( "neo4j://127.0.0.1:9000", )) @mark_async_test -async def test_driver_opens_write_session_by_default(uri, mocker): - driver = create_driver(uri) - from neo4j import AsyncTransaction - +async def test_driver_opens_write_session_by_default(uri, fake_pool, mocker): + driver = AsyncGraphDatabase.driver(uri) # we set a specific db, because else the driver would try to fetch a RT # to get hold of the actual home database (which won't work in this # unittest) + driver._pool = fake_pool async with driver.session(database="foobar") as session: - acquire_mock = mocker.patch.object(session._pool, "acquire", - autospec=True) - tx_begin_mock = mocker.patch.object(AsyncTransaction, "_begin", - autospec=True) + mocker.patch("neo4j._async.work.session.AsyncTransaction", + autospec=True) tx = await session.begin_transaction() - acquire_mock.assert_called_once_with( + fake_pool.acquire.assert_awaited_once_with( access_mode=WRITE_ACCESS, timeout=mocker.ANY, database=mocker.ANY, bookmarks=mocker.ANY, liveness_check_timeout=mocker.ANY ) - tx_begin_mock.assert_called_once_with( - tx, + tx._begin.assert_awaited_once_with( mocker.ANY, mocker.ANY, mocker.ANY, @@ -221,7 +218,7 @@ async def test_driver_opens_write_session_by_default(uri, mocker): )) @mark_async_test async def test_verify_connectivity(uri, mocker): - driver = create_driver(uri) + driver = AsyncGraphDatabase.driver(uri) pool_mock = mocker.patch.object(driver, "_pool", autospec=True) try: @@ -248,7 +245,7 @@ async def test_verify_connectivity(uri, mocker): async def test_verify_connectivity_parameters_are_deprecated( uri, kwargs, mocker ): - driver = create_driver(uri) + driver = AsyncGraphDatabase.driver(uri) mocker.patch.object(driver, "_pool", autospec=True) try: @@ -271,7 +268,7 @@ async def test_verify_connectivity_parameters_are_deprecated( async def test_get_server_info_parameters_are_experimental( uri, kwargs, mocker ): - driver = create_driver(uri) + driver = AsyncGraphDatabase.driver(uri) mocker.patch.object(driver, "_pool", autospec=True) try: @@ -279,3 +276,143 @@ async def test_get_server_info_parameters_are_experimental( await driver.get_server_info(**kwargs) finally: await driver.close() + + +@mark_async_test +async def test_with_builtin_bookmark_manager(mocker) -> None: + with pytest.warns(ExperimentalWarning, match="bookmark manager"): + bmm = AsyncGraphDatabase.bookmark_manager() + # could be one line, but want to make sure the type checker assigns + # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns + session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", + autospec=True) + driver = AsyncGraphDatabase.driver("bolt://localhost") + async with driver as driver: + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + _ = driver.session(bookmark_manager=bmm) + session_cls_mock.assert_called_once() + assert session_cls_mock.call_args[0][1].bookmark_manager is bmm + + +@AsyncTestDecorators.mark_async_only_test +async def test_with_custom_inherited_async_bookmark_manager(mocker) -> None: + class BMM(AsyncBookmarkManager): + async def update_bookmarks( + self, database: str, previous_bookmarks: t.Iterable[str], + new_bookmarks: t.Iterable[str] + ) -> None: + ... + + async def get_bookmarks(self, database: str) -> t.Collection[str]: + ... + + async def get_all_bookmarks(self) -> t.Collection[str]: + ... + + async def forget(self, databases: t.Iterable[str]) -> None: + ... + + bmm = BMM() + # could be one line, but want to make sure the type checker assigns + # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns + session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", + autospec=True) + driver = AsyncGraphDatabase.driver("bolt://localhost") + async with driver as driver: + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + _ = driver.session(bookmark_manager=bmm) + session_cls_mock.assert_called_once() + assert session_cls_mock.call_args[0][1].bookmark_manager is bmm + + +@mark_async_test +async def test_with_custom_inherited_sync_bookmark_manager(mocker) -> None: + class BMM(BookmarkManager): + def update_bookmarks( + self, database: str, previous_bookmarks: t.Iterable[str], + new_bookmarks: t.Iterable[str] + ) -> None: + ... + + def get_bookmarks(self, database: str) -> t.Collection[str]: + ... + + def get_all_bookmarks(self) -> t.Collection[str]: + ... + + def forget(self, databases: t.Iterable[str]) -> None: + ... + + bmm = BMM() + # could be one line, but want to make sure the type checker assigns + # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns + session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", + autospec=True) + driver = AsyncGraphDatabase.driver("bolt://localhost") + async with driver as driver: + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + _ = driver.session(bookmark_manager=bmm) + session_cls_mock.assert_called_once() + assert session_cls_mock.call_args[0][1].bookmark_manager is bmm + + +@AsyncTestDecorators.mark_async_only_test +async def test_with_custom_ducktype_async_bookmark_manager(mocker) -> None: + class BMM: + async def update_bookmarks( + self, database: str, previous_bookmarks: t.Iterable[str], + new_bookmarks: t.Iterable[str] + ) -> None: + ... + + async def get_bookmarks(self, database: str) -> t.Collection[str]: + ... + + async def get_all_bookmarks(self) -> t.Collection[str]: + ... + + async def forget(self, databases: t.Iterable[str]) -> None: + ... + + bmm = BMM() + # could be one line, but want to make sure the type checker assigns + # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns + session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", + autospec=True) + driver = AsyncGraphDatabase.driver("bolt://localhost") + async with driver as driver: + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + _ = driver.session(bookmark_manager=bmm) + session_cls_mock.assert_called_once() + assert session_cls_mock.call_args[0][1].bookmark_manager is bmm + + +@mark_async_test +async def test_with_custom_ducktype_sync_bookmark_manager(mocker) -> None: + class BMM: + def update_bookmarks( + self, database: str, previous_bookmarks: t.Iterable[str], + new_bookmarks: t.Iterable[str] + ) -> None: + ... + + def get_bookmarks(self, database: str) -> t.Collection[str]: + ... + + def get_all_bookmarks(self) -> t.Collection[str]: + ... + + def forget(self, databases: t.Iterable[str]) -> None: + ... + + bmm = BMM() + # could be one line, but want to make sure the type checker assigns + # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns + session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", + autospec=True) + driver = AsyncGraphDatabase.driver("bolt://localhost") + async with driver as driver: + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + _ = driver.session(bookmark_manager=bmm) + session_cls_mock.assert_called_once() + assert session_cls_mock.call_args[0][1].bookmark_manager is bmm diff --git a/tests/unit/async_/work/__init__.py b/tests/unit/async_/work/__init__.py index d5ea8d85c..c42cc6fb6 100644 --- a/tests/unit/async_/work/__init__.py +++ b/tests/unit/async_/work/__init__.py @@ -14,9 +14,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -from ._fake_connection import ( - async_fake_connection, - async_fake_connection_generator, -) diff --git a/tests/unit/async_/work/conftest.py b/tests/unit/async_/work/conftest.py deleted file mode 100644 index 3b60f3efd..000000000 --- a/tests/unit/async_/work/conftest.py +++ /dev/null @@ -1,6 +0,0 @@ -from ._fake_connection import ( - async_fake_connection, - async_fake_connection_generator, - async_scripted_connection, - async_scripted_connection_generator, -) diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index 70d3805b1..b14ffbe67 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -27,27 +27,16 @@ Bookmarks, unit_of_work, ) -from neo4j._async.io._pool import AsyncIOPool +from neo4j._async.io import ( + AsyncBoltPool, + AsyncNeo4jPool, +) from neo4j._conf import SessionConfig +from neo4j.api import AsyncBookmarkManager from ...._async_compat import mark_async_test -@pytest.fixture() -def pool(async_fake_connection_generator, mocker): - pool = mocker.AsyncMock(spec=AsyncIOPool) - assert not hasattr(pool, "acquired_connection_mocks") - pool.acquired_connection_mocks = [] - - def acquire_side_effect(*_, **__): - connection = async_fake_connection_generator() - pool.acquired_connection_mocks.append(connection) - return connection - - pool.acquire.side_effect = acquire_side_effect - return pool - - @mark_async_test async def test_session_context_calls_close(mocker): s = AsyncSession(None, SessionConfig()) @@ -66,9 +55,9 @@ async def test_session_context_calls_close(mocker): )) @mark_async_test async def test_opens_connection_on_run( - pool, test_run_args, repetitions, consume + fake_pool, test_run_args, repetitions, consume ): - async with AsyncSession(pool, SessionConfig()) as session: + async with AsyncSession(fake_pool, SessionConfig()) as session: assert session._connection is None result = await session.run(*test_run_args) assert session._connection is not None @@ -82,9 +71,9 @@ async def test_opens_connection_on_run( @pytest.mark.parametrize("repetitions", range(1, 3)) @mark_async_test async def test_closes_connection_after_consume( - pool, test_run_args, repetitions + fake_pool, test_run_args, repetitions ): - async with AsyncSession(pool, SessionConfig()) as session: + async with AsyncSession(fake_pool, SessionConfig()) as session: result = await session.run(*test_run_args) await result.consume() assert session._connection is None @@ -96,9 +85,9 @@ async def test_closes_connection_after_consume( )) @mark_async_test async def test_keeps_connection_until_last_result_consumed( - pool, test_run_args + fake_pool, test_run_args ): - async with AsyncSession(pool, SessionConfig()) as session: + async with AsyncSession(fake_pool, SessionConfig()) as session: result1 = await session.run(*test_run_args) result2 = await session.run(*test_run_args) assert session._connection is not None @@ -109,8 +98,8 @@ async def test_keeps_connection_until_last_result_consumed( @mark_async_test -async def test_opens_connection_on_tx_begin(pool): - async with AsyncSession(pool, SessionConfig()) as session: +async def test_opens_connection_on_tx_begin(fake_pool): + async with AsyncSession(fake_pool, SessionConfig()) as session: assert session._connection is None async with await session.begin_transaction() as _: assert session._connection is not None @@ -121,8 +110,10 @@ async def test_opens_connection_on_tx_begin(pool): )) @pytest.mark.parametrize("repetitions", range(1, 3)) @mark_async_test -async def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions): - async with AsyncSession(pool, SessionConfig()) as session: +async def test_keeps_connection_on_tx_run( + fake_pool, test_run_args, repetitions +): + async with AsyncSession(fake_pool, SessionConfig()) as session: async with await session.begin_transaction() as tx: for _ in range(repetitions): await tx.run(*test_run_args) @@ -135,9 +126,9 @@ async def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions): @pytest.mark.parametrize("repetitions", range(1, 3)) @mark_async_test async def test_keeps_connection_on_tx_consume( - pool, test_run_args, repetitions + fake_pool, test_run_args, repetitions ): - async with AsyncSession(pool, SessionConfig()) as session: + async with AsyncSession(fake_pool, SessionConfig()) as session: async with await session.begin_transaction() as tx: for _ in range(repetitions): result = await tx.run(*test_run_args) @@ -149,8 +140,8 @@ async def test_keeps_connection_on_tx_consume( ("RETURN $x", {"x": 1}), ("RETURN 1",) )) @mark_async_test -async def test_closes_connection_after_tx_close(pool, test_run_args): - async with AsyncSession(pool, SessionConfig()) as session: +async def test_closes_connection_after_tx_close(fake_pool, test_run_args): + async with AsyncSession(fake_pool, SessionConfig()) as session: async with await session.begin_transaction() as tx: for _ in range(2): result = await tx.run(*test_run_args) @@ -164,8 +155,8 @@ async def test_closes_connection_after_tx_close(pool, test_run_args): ("RETURN $x", {"x": 1}), ("RETURN 1",) )) @mark_async_test -async def test_closes_connection_after_tx_commit(pool, test_run_args): - async with AsyncSession(pool, SessionConfig()) as session: +async def test_closes_connection_after_tx_commit(fake_pool, test_run_args): + async with AsyncSession(fake_pool, SessionConfig()) as session: async with await session.begin_transaction() as tx: for _ in range(2): result = await tx.run(*test_run_args) @@ -180,13 +171,13 @@ async def test_closes_connection_after_tx_commit(pool, test_run_args): (None, [], ["abc"], ["foo", "bar"], {"a", "b"}, ("1", "two")) ) @mark_async_test -async def test_session_returns_bookmarks_directly(pool, bookmark_values): +async def test_session_returns_bookmarks_directly(fake_pool, bookmark_values): if bookmark_values is not None: bookmarks = Bookmarks.from_raw_values(bookmark_values) else: bookmarks = Bookmarks() async with AsyncSession( - pool, SessionConfig(bookmarks=bookmarks) + fake_pool, SessionConfig(bookmarks=bookmarks) ) as session: ret_bookmarks = (await session.last_bookmarks()) assert isinstance(ret_bookmarks, Bookmarks) @@ -202,12 +193,13 @@ async def test_session_returns_bookmarks_directly(pool, bookmark_values): (None, [], ["abc"], ["foo", "bar"], ("1", "two")) ) @mark_async_test -async def test_session_last_bookmark_is_deprecated(pool, bookmarks): +async def test_session_last_bookmark_is_deprecated(fake_pool, bookmarks): if bookmarks is not None: with pytest.warns(DeprecationWarning): - session = AsyncSession(pool, SessionConfig(bookmarks=bookmarks)) + session = AsyncSession(fake_pool, + SessionConfig(bookmarks=bookmarks)) else: - session = AsyncSession(pool, SessionConfig(bookmarks=bookmarks)) + session = AsyncSession(fake_pool, SessionConfig(bookmarks=bookmarks)) async with session: with pytest.warns(DeprecationWarning): if bookmarks: @@ -221,9 +213,11 @@ async def test_session_last_bookmark_is_deprecated(pool, bookmarks): (("foo",), ("foo", "bar"), (), ["foo", "bar"], {"a", "b"}) ) @mark_async_test -async def test_session_bookmarks_as_iterable_is_deprecated(pool, bookmarks): +async def test_session_bookmarks_as_iterable_is_deprecated( + fake_pool, bookmarks +): with pytest.warns(DeprecationWarning): - async with AsyncSession(pool, SessionConfig( + async with AsyncSession(fake_pool, SessionConfig( bookmarks=bookmarks )) as session: ret_bookmarks = (await session.last_bookmarks()).raw_values @@ -237,19 +231,19 @@ async def test_session_bookmarks_as_iterable_is_deprecated(pool, bookmarks): (["I don't", "think so"], TypeError), )) @mark_async_test -async def test_session_run_wrong_types(pool, query, error_type): - async with AsyncSession(pool, SessionConfig()) as session: +async def test_session_run_wrong_types(fake_pool, query, error_type): + async with AsyncSession(fake_pool, SessionConfig()) as session: with pytest.raises(error_type): await session.run(query) @pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction")) @mark_async_test -async def test_tx_function_argument_type(pool, tx_type): +async def test_tx_function_argument_type(fake_pool, tx_type): async def work(tx): assert isinstance(tx, AsyncManagedTransaction) - async with AsyncSession(pool, SessionConfig()) as session: + async with AsyncSession(fake_pool, SessionConfig()) as session: await getattr(session, tx_type)(work) @@ -262,18 +256,20 @@ async def work(tx): )) @mark_async_test -async def test_decorated_tx_function_argument_type(pool, tx_type, decorator_kwargs): +async def test_decorated_tx_function_argument_type( + fake_pool, tx_type, decorator_kwargs +): @unit_of_work(**decorator_kwargs) async def work(tx): assert isinstance(tx, AsyncManagedTransaction) - async with AsyncSession(pool, SessionConfig()) as session: + async with AsyncSession(fake_pool, SessionConfig()) as session: await getattr(session, tx_type)(work) @mark_async_test -async def test_session_tx_type(pool): - async with AsyncSession(pool, SessionConfig()) as session: +async def test_session_tx_type(fake_pool): + async with AsyncSession(fake_pool, SessionConfig()) as session: tx = await session.begin_transaction() assert isinstance(tx, AsyncTransaction) @@ -300,9 +296,9 @@ async def test_session_tx_type(pool): @pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed")) @mark_async_test async def test_session_run_with_parameters( - pool, parameters, run_type, mocker + fake_pool, parameters, run_type, mocker ): - async with AsyncSession(pool, SessionConfig()) as session: + async with AsyncSession(fake_pool, SessionConfig()) as session: if run_type == "auto": await session.run("RETURN $x", **parameters) elif run_type == "unmanaged": @@ -315,9 +311,165 @@ async def work(tx): else: raise ValueError(run_type) - assert len(pool.acquired_connection_mocks) == 1 - connection_mock = pool.acquired_connection_mocks[0] + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] assert connection_mock.run.called_once() call = connection_mock.run.call_args assert call.args[0] == "RETURN $x" assert call.kwargs["parameters"] == parameters + + +@pytest.mark.parametrize("db", (None, "adb")) +@pytest.mark.parametrize("routing", (True, False)) +# no home db resolution when connected to Neo4j 4.3 or earlier +@pytest.mark.parametrize("home_db_gets_resolved", (True, False)) +@pytest.mark.parametrize("additional_session_bookmarks", + (None, ["session", "bookmarks"])) +@mark_async_test +async def test_with_bookmark_manager( + fake_pool, db, routing, async_scripted_connection, home_db_gets_resolved, + additional_session_bookmarks, mocker +): + async def update_routing_table_side_effect( + database, imp_user, bookmarks, acquisition_timeout=None, + database_callback=None + ): + if home_db_gets_resolved: + database_callback("homedb") + + async def bmm_get_bookmarks(database): + return [f"{database}:bm1"] + + async def bmm_gat_all_bookmarks(): + return ["all", "bookmarks"] + + async_scripted_connection.set_script([ + ("run", {"on_success": None, "on_summary": None}), + ("pull", { + "on_success": ({"bookmark": "res:bm1", "has_more": False},), + "on_summary": None, + "on_records": None, + }) + ]) + fake_pool.buffered_connection_mocks.append(async_scripted_connection) + + bmm = mocker.Mock(spec=AsyncBookmarkManager) + bmm.get_bookmarks.side_effect = bmm_get_bookmarks + bmm.get_all_bookmarks.side_effect = bmm_gat_all_bookmarks + + if routing: + fake_pool.mock_add_spec(AsyncNeo4jPool) + fake_pool.update_routing_table.side_effect = \ + update_routing_table_side_effect + else: + fake_pool.mock_add_spec(AsyncBoltPool) + + config = SessionConfig() + config.bookmark_manager = bmm + if db is not None: + config.database = db + if additional_session_bookmarks: + config.bookmarks = Bookmarks.from_raw_values( + additional_session_bookmarks + ) + async with AsyncSession(fake_pool, config) as session: + assert not bmm.method_calls + + await session.run("RETURN 1") + + # assert called bmm accordingly + expected_bmm_method_calls = [mocker.call.get_bookmarks("system"), + mocker.call.get_all_bookmarks()] + if routing and db is None: + expected_bmm_method_calls = [ + # extra call for resolving the home database + mocker.call.get_bookmarks("system"), + *expected_bmm_method_calls + ] + assert bmm.method_calls == expected_bmm_method_calls + assert (bmm.get_bookmarks.await_count + == len(expected_bmm_method_calls) - 1) + bmm.get_all_bookmarks.assert_awaited_once() + bmm.method_calls.clear() + + expected_update_for_db = db + if not db: + if home_db_gets_resolved and routing: + expected_update_for_db = "homedb" + else: + expected_update_for_db = "" + assert [call[0] for call in bmm.method_calls] == ["update_bookmarks"] + assert bmm.method_calls[0].kwargs == {} + assert len(bmm.method_calls[0].args) == 3 + assert bmm.method_calls[0].args[0] == expected_update_for_db + assert (set(bmm.method_calls[0].args[1]) + == {"all", "bookmarks", *(additional_session_bookmarks or [])}) + assert set(bmm.method_calls[0].args[2]) == {"res:bm1"} + + expected_pool_method_calls = ["acquire", "release"] + if routing and db is None: + expected_pool_method_calls = ["update_routing_table", + *expected_pool_method_calls] + assert ([call[0] for call in fake_pool.method_calls] + == expected_pool_method_calls) + assert (set(fake_pool.acquire.call_args.kwargs["bookmarks"]) + == {"system:bm1", *(additional_session_bookmarks or [])}) + if routing and db is None: + assert ( + set(fake_pool.update_routing_table.call_args.kwargs["bookmarks"]) + == {"system:bm1", *(additional_session_bookmarks or [])} + ) + + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] + assert connection_mock.run.called_once() + connection_run_call_kwargs = connection_mock.run.call_args.kwargs + assert (set(connection_run_call_kwargs["bookmarks"]) + == {"all", "bookmarks", *(additional_session_bookmarks or [])}) + + +@pytest.mark.parametrize("routing", (True, False)) +@pytest.mark.parametrize("session_method", ("run", "get_server_info")) +@mark_async_test +async def test_last_bookmarks_do_not_leak_bookmark_managers_bookmarks( + fake_pool, routing, session_method, mocker +): + async def bmm_get_bookmarks(database): + return [f"bmm:{database}"] + + async def bmm_gat_all_bookmarks(): + return ["bmm:all", "bookmarks"] + + fake_pool.mock_add_spec(AsyncNeo4jPool if routing else AsyncBoltPool) + + bmm = mocker.Mock(spec=AsyncBookmarkManager) + bmm.get_bookmarks.side_effect = bmm_get_bookmarks + bmm.get_all_bookmarks.side_effect = bmm_gat_all_bookmarks + + config = SessionConfig() + config.bookmark_manager = bmm + config.bookmarks = Bookmarks.from_raw_values(["session", "bookmarks"]) + async with AsyncSession(fake_pool, config) as session: + if session_method == "run": + await session.run("RETURN 1") + elif session_method == "get_server_info": + await session._get_server_info() + else: + assert False + last_bookmarks = await session.last_bookmarks() + + assert last_bookmarks.raw_values == {"session", "bookmarks"} + assert last_bookmarks.raw_values == {"session", "bookmarks"} + + +@mark_async_test +async def test_with_ignored_bookmark_manager(fake_pool, mocker): + bmm = mocker.Mock(spec=AsyncBookmarkManager) + session_config = SessionConfig() + session_config.bookmark_manager = bmm + session_config.ignore_bookmark_manager = True + async with AsyncSession(fake_pool, session_config) as session: + await session.run("RETURN 1") + + bmm.assert_not_called() + assert not bmm.method_calls diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index 72ee07f2d..8c6601ad5 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -19,6 +19,7 @@ import pytest from neo4j import ( + ExperimentalWarning, TrustAll, TrustCustomCAs, TrustSystemCAs, @@ -66,6 +67,8 @@ "database": None, "impersonated_user": None, "fetch_size": 100, + "bookmark_manager": object(), + "ignore_bookmark_manager": False, } @@ -182,6 +185,14 @@ def test_pool_config_deprecated_and_new_trust_config(value_trust, "trusted_certificates": trusted_certificates}) +@pytest.mark.parametrize("config_cls", (WorkspaceConfig, SessionConfig)) +def test_bookmark_manager_is_experimental(config_cls): + bmm = object() + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + config = config_cls.consume({"bookmark_manager": bmm}) + assert config.bookmark_manager is bmm + + def test_config_consume_chain(): test_config = {} @@ -190,7 +201,10 @@ def test_config_consume_chain(): test_config.update(test_session_config) - consumed_pool_config, consumed_session_config = Config.consume_chain(test_config, PoolConfig, SessionConfig) + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + consumed_pool_config, consumed_session_config = Config.consume_chain( + test_config, PoolConfig, SessionConfig + ) assert isinstance(consumed_pool_config, PoolConfig) assert isinstance(consumed_session_config, SessionConfig) diff --git a/tests/unit/sync/conftest.py b/tests/unit/sync/conftest.py new file mode 100644 index 000000000..9d171987d --- /dev/null +++ b/tests/unit/sync/conftest.py @@ -0,0 +1,19 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .fixtures import * # necessary for pytest to discover the fixtures diff --git a/tests/unit/sync/fixtures/__init__.py b/tests/unit/sync/fixtures/__init__.py new file mode 100644 index 000000000..c3e907d44 --- /dev/null +++ b/tests/unit/sync/fixtures/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .fake_connection import * +from .fake_pool import * diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py new file mode 100644 index 000000000..f3a8e695a --- /dev/null +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -0,0 +1,215 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect + +import pytest + +from neo4j import ServerInfo +from neo4j._deadline import Deadline +from neo4j._sync.io import Bolt + + +__all__ = [ + "fake_connection_generator", + "fake_connection", + "scripted_connection_generator", + "scripted_connection", +] + + +@pytest.fixture +def fake_connection_generator(session_mocker): + mock = session_mocker.mock_module + + class FakeConnection(mock.NonCallableMagicMock): + callbacks = [] + server_info = ServerInfo("127.0.0.1", (4, 3)) + local_port = 1234 + + def __init__(self, *args, **kwargs): + kwargs["spec"] = Bolt + super().__init__(*args, **kwargs) + self.attach_mock(mock.Mock(return_value=True), "is_reset_mock") + self.attach_mock(mock.Mock(return_value=False), "defunct") + self.attach_mock(mock.Mock(return_value=False), "stale") + self.attach_mock(mock.Mock(return_value=False), "closed") + self.attach_mock(mock.Mock(return_value=False), "socket") + self.attach_mock(mock.Mock(), "unresolved_address") + + def close_side_effect(): + self.closed.return_value = True + + self.attach_mock(mock.Mock(side_effect=close_side_effect), + "close") + + self.socket.attach_mock( + mock.Mock(return_value=None), "get_deadline" + ) + + def set_deadline_side_effect(deadline): + deadline = Deadline.from_timeout_or_deadline(deadline) + self.socket.get_deadline.return_value = deadline + + self.socket.attach_mock( + mock.Mock(side_effect=set_deadline_side_effect), "set_deadline" + ) + + @property + def is_reset(self): + if self.closed.return_value or self.defunct.return_value: + raise AssertionError( + "is_reset should not be called on a closed or defunct " + "connection." + ) + return self.is_reset_mock() + + def fetch_message(self, *args, **kwargs): + if self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_message")(*args, **kwargs) + + def fetch_all(self, *args, **kwargs): + while self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_all")(*args, **kwargs) + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + def callback(): + for cb_name, param_count in ( + ("on_success", 1), + ("on_summary", 0) + ): + cb = kwargs.get(cb_name, None) + if callable(cb): + try: + param_count = \ + len(inspect.signature(cb).parameters) + except ValueError: + # e.g. built-in method as cb + pass + if param_count == 1: + res = cb({}) + else: + res = cb() + try: + res # maybe the callback is async + except TypeError: + pass # or maybe it wasn't ;) + + self.callbacks.append(callback) + + return func + + method_mock = parent.__getattr__(name) + if name in ("run", "commit", "pull", "rollback", "discard"): + method_mock.side_effect = build_message_handler(name) + return method_mock + + return FakeConnection + + +@pytest.fixture +def fake_connection(fake_connection_generator): + return fake_connection_generator() + + +@pytest.fixture +def scripted_connection_generator(fake_connection_generator): + class ScriptedConnection(fake_connection_generator): + _script = [] + _script_pos = 0 + + def set_script(self, callbacks): + """Set a scripted sequence of callbacks. + + :param callbacks: The callbacks. They should be a list of 2-tuples. + `("name_of_message", {"callback_name": arguments})`. E.g., + ``` + [ + ("run", {"on_success": ({},), "on_summary": None}), + ("pull", { + "on_success": None, + "on_summary": None, + "on_records": + }) + ] + ``` + Note that arguments can be `None`. In this case, ScriptedConnection + will make a guess on best-suited default arguments. + """ + self._script = callbacks + self._script_pos = 0 + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + try: + expected_message, scripted_callbacks = \ + self._script[self._script_pos] + except IndexError: + pytest.fail("End of scripted connection reached.") + assert name == expected_message + self._script_pos += 1 + + def callback(): + for cb_name, default_cb_args in ( + ("on_ignored", ({},)), + ("on_failure", ({},)), + ("on_records", ([],)), + ("on_success", ({},)), + ("on_summary", ()), + ): + cb = kwargs.get(cb_name, None) + if ( + not callable(cb) + or cb_name not in scripted_callbacks + ): + continue + cb_args = scripted_callbacks[cb_name] + if cb_args is None: + cb_args = default_cb_args + res = cb(*cb_args) + try: + res # maybe the callback is async + except TypeError: + pass # or maybe it wasn't ;) + + self.callbacks.append(callback) + + return func + + method_mock = parent.__getattr__(name) + if name in ("run", "commit", "pull", "rollback", "discard"): + method_mock.side_effect = build_message_handler(name) + return method_mock + + return ScriptedConnection + + +@pytest.fixture +def scripted_connection(scripted_connection_generator): + return scripted_connection_generator() diff --git a/tests/unit/sync/fixtures/fake_pool.py b/tests/unit/sync/fixtures/fake_pool.py new file mode 100644 index 000000000..63f5853ed --- /dev/null +++ b/tests/unit/sync/fixtures/fake_pool.py @@ -0,0 +1,45 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._sync.io._pool import IOPool + + +__all__ = [ + "fake_pool", +] + + +@pytest.fixture +def fake_pool(fake_connection_generator, mocker): + pool = mocker.Mock(spec=IOPool) + assert not hasattr(pool, "acquired_connection_mocks") + pool.buffered_connection_mocks = [] + pool.acquired_connection_mocks = [] + + def acquire_side_effect(*_, **__): + if pool.buffered_connection_mocks: + connection = pool.buffered_connection_mocks.pop() + else: + connection = fake_connection_generator() + pool.acquired_connection_mocks.append(connection) + return connection + + pool.acquire.side_effect = acquire_side_effect + return pool diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 564660966..d15bed04f 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -244,7 +244,7 @@ def test_hint_recv_timeout_seconds( sockets = fake_socket_pair(address, packer_cls=Bolt4x4.PACKER_CLS, unpacker_cls=Bolt4x4.UNPACKER_CLS) - sockets.client.settimeout = mocker.MagicMock() + sockets.client.settimeout = mocker.Mock() sockets.server.send_message( b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 6b9f8d7eb..bca9d4441 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -38,7 +38,6 @@ ) from ...._async_compat import mark_sync_test -from ..work import fake_connection_generator # needed as fixture ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") diff --git a/tests/unit/sync/test_addressing.py b/tests/unit/sync/test_addressing.py index 2779eb8bf..444b289d4 100644 --- a/tests/unit/sync/test_addressing.py +++ b/tests/unit/sync/test_addressing.py @@ -34,7 +34,7 @@ @mark_sync_test -def test_address_resolve(): +def test_address_resolve() -> None: address = Address(("127.0.0.1", 7687)) resolved = NetworkUtil.resolve_address(address) resolved = Util.list(resolved) @@ -45,7 +45,7 @@ def test_address_resolve(): @mark_sync_test -def test_address_resolve_with_custom_resolver_none(): +def test_address_resolve_with_custom_resolver_none() -> None: address = Address(("127.0.0.1", 7687)) resolved = NetworkUtil.resolve_address(address, resolver=None) resolved = Util.list(resolved) @@ -64,7 +64,9 @@ def test_address_resolve_with_custom_resolver_none(): ) @mark_sync_test -def test_address_resolve_with_unresolvable_address(test_input, expected): +def test_address_resolve_with_unresolvable_address( + test_input, expected +) -> None: with pytest.raises(expected): Util.list( NetworkUtil.resolve_address(test_input, resolver=None) @@ -73,7 +75,7 @@ def test_address_resolve_with_unresolvable_address(test_input, expected): @mark_sync_test @pytest.mark.parametrize("resolver_type", ("sync", "async")) -def test_address_resolve_with_custom_resolver(resolver_type): +def test_address_resolve_with_custom_resolver(resolver_type) -> None: def custom_resolver_sync(_): return [("127.0.0.1", 7687), ("localhost", 1234)] @@ -98,9 +100,11 @@ def custom_resolver_async(_): @mark_sync_test -def test_address_unresolve(): +def test_address_unresolve() -> None: + def custom_resolver(_): + return custom_resolved + custom_resolved = [("127.0.0.1", 7687), ("localhost", 4321)] - custom_resolver = lambda _: custom_resolved address = Address(("foobar", 1234)) unresolved = address.unresolved @@ -109,9 +113,9 @@ def test_address_unresolve(): resolved = NetworkUtil.resolve_address( address, family=AF_INET, resolver=custom_resolver ) - resolved = Util.list(resolved) - custom_resolved = sorted(Address(a) for a in custom_resolved) - unresolved = sorted(a.unresolved for a in resolved) - assert custom_resolved == unresolved - assert (list(map(lambda a: a.__class__, custom_resolved)) - == list(map(lambda a: a.__class__, unresolved))) + resolved_list = Util.list(resolved) + custom_resolved_addresses = sorted(Address(a) for a in custom_resolved) + unresolved_list = sorted(a.unresolved for a in resolved_list) + assert custom_resolved_addresses == unresolved_list + assert (list(map(lambda a: a.__class__, custom_resolved_addresses)) + == list(map(lambda a: a.__class__, unresolved_list))) diff --git a/tests/unit/sync/test_bookmark_manager.py b/tests/unit/sync/test_bookmark_manager.py new file mode 100644 index 000000000..e7bf9f166 --- /dev/null +++ b/tests/unit/sync/test_bookmark_manager.py @@ -0,0 +1,284 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import itertools +import typing as t + +import pytest + +import neo4j +from neo4j._async_compat.util import Util +from neo4j._meta import copy_signature +from neo4j._sync.bookmark_manager import Neo4jBookmarkManager +from neo4j.api import Bookmarks + +from ..._async_compat import mark_sync_test + + +supplier_async_options = (True, False) if Util.is_async_code else (False,) +consumer_async_options = supplier_async_options + + +@copy_signature(neo4j.GraphDatabase.bookmark_manager) +def bookmark_manager(*args, **kwargs): + with pytest.warns(neo4j.ExperimentalWarning, match="bookmark manager"): + return neo4j.GraphDatabase.bookmark_manager(*args, **kwargs) + + +@pytest.mark.parametrize("db", ("foobar", "system")) +@mark_sync_test +def test_return_empty_if_db_doesnt_exists(db) -> None: + bmm = bookmark_manager() + + assert set(bmm.get_bookmarks(db)) == set() + + +@pytest.mark.parametrize("db", ("db1", "db2", "db3")) +@mark_sync_test +def test_return_initial_bookmarks_for_the_given_db(db) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"] + } + bmm = bookmark_manager(initial_bookmarks=initial_bookmarks) + + assert set(bmm.get_bookmarks(db)) == set(initial_bookmarks[db]) + + +@pytest.mark.parametrize("db", ("db1", "db2", "db3")) +@pytest.mark.parametrize("supplier_async", supplier_async_options) +@mark_sync_test +def test_return_get_bookmarks_from_bookmarks_supplier( + db, mocker, supplier_async +) -> None: + extra_bookmarks = ["foo:bm1", "bar:bm2", "foo:bm1"] + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"] + } + mock_cls = mocker.Mock if supplier_async else mocker.Mock + supplier = mock_cls( + return_value=Bookmarks.from_raw_values(extra_bookmarks) + ) + bmm = bookmark_manager(initial_bookmarks=initial_bookmarks, + bookmarks_supplier=supplier) + + assert set(bmm.get_bookmarks(db)) == { + *extra_bookmarks, *initial_bookmarks.get(db, []) + } + if supplier_async: + supplier.assert_called_once_with(db) + else: + supplier.assert_called_once_with(db) + + +@pytest.mark.parametrize("with_initial_bookmarks", (True, False)) +@mark_sync_test +def test_return_all_bookmarks(with_initial_bookmarks) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"], + "db5": ["db3:bm1"] + } + bmm = bookmark_manager( + initial_bookmarks=initial_bookmarks if with_initial_bookmarks else None + ) + + all_bookmarks = bmm.get_all_bookmarks() + + if with_initial_bookmarks: + assert all_bookmarks == set( + itertools.chain.from_iterable(initial_bookmarks.values()) + ) + else: + assert all_bookmarks == set() + + +@pytest.mark.parametrize("with_initial_bookmarks", (True, False)) +@pytest.mark.parametrize("supplier_async", supplier_async_options) +@mark_sync_test +def test_return_enriched_bookmarks_list_with_supplied_bookmarks( + with_initial_bookmarks, supplier_async, mocker +) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"], + } + extra_bookmarks = ["foo:bm1", "bar:bm2", "db3:bm1", "foo:bm1"] + mock_cls = mocker.Mock if supplier_async else mocker.Mock + supplier = mock_cls( + return_value=Bookmarks.from_raw_values(extra_bookmarks) + ) + bmm = bookmark_manager( + initial_bookmarks=(initial_bookmarks + if with_initial_bookmarks else None), + bookmarks_supplier=supplier + ) + + all_bookmarks = bmm.get_all_bookmarks() + + if with_initial_bookmarks: + assert all_bookmarks == set( + itertools.chain(*initial_bookmarks.values(), extra_bookmarks) + ) + else: + assert all_bookmarks == set(extra_bookmarks) + if supplier_async: + supplier.assert_called_once_with(None) + else: + supplier.assert_called_once_with(None) + + +@mark_sync_test +def test_chains_bookmarks_for_existing_db() -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"], + } + bmm = bookmark_manager(initial_bookmarks=initial_bookmarks) + bmm.update_bookmarks("db3", ["db3:bm1"], ["db3:bm3"]) + new_bookmarks = bmm.get_bookmarks("db3") + all_bookmarks = bmm.get_all_bookmarks() + + assert new_bookmarks == {"db3:bm2", "db3:bm3"} + assert all_bookmarks == set( + itertools.chain.from_iterable(initial_bookmarks.values()) + ) - {"db3:bm1"} | {"db3:bm2", "db3:bm3"} + + +@mark_sync_test +def test_add_bookmarks_for_a_non_existing_database() -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1"], + "db2": [], + "db3": ["db3:bm1", "db3:bm2"], + "db4": ["db4:bm4"], + } + bmm = bookmark_manager(initial_bookmarks=initial_bookmarks) + bmm.update_bookmarks( + "db5", ["db3:bm1", "db5:bm1"], ["db3:bm3", "db3:bm5"] + ) + new_bookmarks = bmm.get_bookmarks("db5") + all_bookmarks = bmm.get_all_bookmarks() + + assert new_bookmarks == {"db3:bm3", "db3:bm5"} + assert all_bookmarks == set( + itertools.chain.from_iterable(initial_bookmarks.values()) + ) | {"db3:bm3", "db3:bm5"} + + +@pytest.mark.parametrize("with_initial_bookmarks", (True, False)) +@pytest.mark.parametrize("consumer_async", consumer_async_options) +@pytest.mark.parametrize("db", ("db1", "db2", "db3")) +@mark_sync_test +def test_notify_on_new_bookmarks( + with_initial_bookmarks, consumer_async, db, mocker +) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1", "db1:bm2"], + "db2": ["db2:bm1"], + } + mock_cls = mocker.Mock if consumer_async else mocker.Mock + consumer = mock_cls() + bmm = bookmark_manager( + initial_bookmarks=(initial_bookmarks + if with_initial_bookmarks else None), + bookmarks_consumer=consumer + ) + bookmarks_old = {"db1:bm1", "db3:bm1"} + bookmarks_new = {"db1:bm4"} + bmm.update_bookmarks(db, bookmarks_old, bookmarks_new) + + if consumer_async: + consumer.assert_called_once() + args = consumer.await_args.args + else: + consumer.assert_called_once() + args = consumer.call_args.args + assert args[0] == db + assert isinstance(args[1], Bookmarks) + if with_initial_bookmarks: + expected_bms = ( + set(initial_bookmarks.get(db, [])) - bookmarks_old | bookmarks_new + ) + else: + expected_bms = bookmarks_new + assert args[1].raw_values == expected_bms + + +@pytest.mark.parametrize("consumer_async", consumer_async_options) +@pytest.mark.parametrize("with_initial_bookmarks", (True, False)) +@pytest.mark.parametrize("db", ("db1", "db2")) +@mark_sync_test +def test_does_not_notify_on_empty_new_bookmark_set( + with_initial_bookmarks, consumer_async, db, mocker +) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm2"] + } + mock_cls = mocker.Mock if consumer_async else mocker.Mock + consumer = mock_cls() + bmm = bookmark_manager( + initial_bookmarks=(initial_bookmarks + if with_initial_bookmarks else None), + bookmarks_consumer=consumer + ) + bmm.update_bookmarks(db, ["db1:bm1"], []) + + consumer.assert_not_called() + + +@pytest.mark.parametrize("dbs", ( + ["db1"], ["db2"], ["db1", "db2"], ["db1", "db3"], ["db1", "db2", "db3"] +)) +@mark_sync_test +def test_forget_database(dbs) -> None: + initial_bookmarks: t.Dict[str, t.List[str]] = { + "db1": ["db1:bm1", "db1:bm1", "db1:bm2"], + "db2": ["db2:bm1"], + } + bmm = bookmark_manager(initial_bookmarks=initial_bookmarks) + + for db in dbs: + assert (bmm.get_bookmarks(db) + == set(initial_bookmarks.get(db, []))) + + bmm.forget(dbs) + + # assert the key has been removed (memory optimization) + assert isinstance(bmm, Neo4jBookmarkManager) + assert (set(bmm._bookmarks.keys()) + == set(initial_bookmarks.keys()) - set(dbs)) + + for db in dbs: + assert bmm.get_bookmarks(db) == set() + assert bmm.get_all_bookmarks() == set( + bm for k, v in initial_bookmarks.items() if k not in dbs for bm in v + ) diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index 07a85a841..34e38d80b 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -16,8 +16,10 @@ # limitations under the License. +from __future__ import annotations + import ssl -from functools import wraps +import typing as t import pytest @@ -32,19 +34,17 @@ TrustCustomCAs, TrustSystemCAs, ) -from neo4j._async_compat.util import Util from neo4j.api import ( + BookmarkManager, READ_ACCESS, WRITE_ACCESS, ) from neo4j.exceptions import ConfigurationError -from ..._async_compat import mark_sync_test - - -@wraps(GraphDatabase.driver) -def create_driver(*args, **kwargs): - return GraphDatabase.driver(*args, **kwargs) +from ..._async_compat import ( + mark_sync_test, + TestDecorators, +) @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @@ -60,7 +60,7 @@ def test_direct_driver_constructor(protocol, host, port, params, auth_token): with pytest.warns(DeprecationWarning, match="routing context"): driver = GraphDatabase.driver(uri, auth=auth_token) else: - driver = create_driver(uri, auth=auth_token) + driver = GraphDatabase.driver(uri, auth=auth_token) assert isinstance(driver, BoltDriver) driver.close() @@ -75,7 +75,7 @@ def test_direct_driver_constructor(protocol, host, port, params, auth_token): @mark_sync_test def test_routing_driver_constructor(protocol, host, port, params, auth_token): uri = protocol + host + port + params - driver = create_driver(uri, auth=auth_token) + driver = GraphDatabase.driver(uri, auth=auth_token) assert isinstance(driver, Neo4jDriver) driver.close() @@ -140,7 +140,7 @@ def driver_builder(): with pytest.warns(DeprecationWarning, match="trust"): return GraphDatabase.driver(test_uri, **test_config) else: - return create_driver(test_uri, **test_config) + return GraphDatabase.driver(test_uri, **test_config) if "+" in test_uri: # `+s` and `+ssc` are short hand syntax for not having to configure the @@ -159,7 +159,7 @@ def driver_builder(): )) def test_invalid_protocol(test_uri): with pytest.raises(ConfigurationError, match="scheme"): - create_driver(test_uri) + GraphDatabase.driver(test_uri) @pytest.mark.parametrize( @@ -174,7 +174,7 @@ def test_driver_trust_config_error( test_config, expected_failure, expected_failure_message ): with pytest.raises(expected_failure, match=expected_failure_message): - create_driver("bolt://127.0.0.1:9001", **test_config) + GraphDatabase.driver("bolt://127.0.0.1:9001", **test_config) @pytest.mark.parametrize("uri", ( @@ -182,28 +182,24 @@ def test_driver_trust_config_error( "neo4j://127.0.0.1:9000", )) @mark_sync_test -def test_driver_opens_write_session_by_default(uri, mocker): - driver = create_driver(uri) - from neo4j import Transaction - +def test_driver_opens_write_session_by_default(uri, fake_pool, mocker): + driver = GraphDatabase.driver(uri) # we set a specific db, because else the driver would try to fetch a RT # to get hold of the actual home database (which won't work in this # unittest) + driver._pool = fake_pool with driver.session(database="foobar") as session: - acquire_mock = mocker.patch.object(session._pool, "acquire", - autospec=True) - tx_begin_mock = mocker.patch.object(Transaction, "_begin", - autospec=True) + mocker.patch("neo4j._sync.work.session.Transaction", + autospec=True) tx = session.begin_transaction() - acquire_mock.assert_called_once_with( + fake_pool.acquire.assert_called_once_with( access_mode=WRITE_ACCESS, timeout=mocker.ANY, database=mocker.ANY, bookmarks=mocker.ANY, liveness_check_timeout=mocker.ANY ) - tx_begin_mock.assert_called_once_with( - tx, + tx._begin.assert_called_once_with( mocker.ANY, mocker.ANY, mocker.ANY, @@ -221,7 +217,7 @@ def test_driver_opens_write_session_by_default(uri, mocker): )) @mark_sync_test def test_verify_connectivity(uri, mocker): - driver = create_driver(uri) + driver = GraphDatabase.driver(uri) pool_mock = mocker.patch.object(driver, "_pool", autospec=True) try: @@ -248,7 +244,7 @@ def test_verify_connectivity(uri, mocker): def test_verify_connectivity_parameters_are_deprecated( uri, kwargs, mocker ): - driver = create_driver(uri) + driver = GraphDatabase.driver(uri) mocker.patch.object(driver, "_pool", autospec=True) try: @@ -271,7 +267,7 @@ def test_verify_connectivity_parameters_are_deprecated( def test_get_server_info_parameters_are_experimental( uri, kwargs, mocker ): - driver = create_driver(uri) + driver = GraphDatabase.driver(uri) mocker.patch.object(driver, "_pool", autospec=True) try: @@ -279,3 +275,143 @@ def test_get_server_info_parameters_are_experimental( driver.get_server_info(**kwargs) finally: driver.close() + + +@mark_sync_test +def test_with_builtin_bookmark_manager(mocker) -> None: + with pytest.warns(ExperimentalWarning, match="bookmark manager"): + bmm = GraphDatabase.bookmark_manager() + # could be one line, but want to make sure the type checker assigns + # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns + session_cls_mock = mocker.patch("neo4j._sync.driver.Session", + autospec=True) + driver = GraphDatabase.driver("bolt://localhost") + with driver as driver: + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + _ = driver.session(bookmark_manager=bmm) + session_cls_mock.assert_called_once() + assert session_cls_mock.call_args[0][1].bookmark_manager is bmm + + +@TestDecorators.mark_async_only_test +def test_with_custom_inherited_async_bookmark_manager(mocker) -> None: + class BMM(BookmarkManager): + def update_bookmarks( + self, database: str, previous_bookmarks: t.Iterable[str], + new_bookmarks: t.Iterable[str] + ) -> None: + ... + + def get_bookmarks(self, database: str) -> t.Collection[str]: + ... + + def get_all_bookmarks(self) -> t.Collection[str]: + ... + + def forget(self, databases: t.Iterable[str]) -> None: + ... + + bmm = BMM() + # could be one line, but want to make sure the type checker assigns + # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns + session_cls_mock = mocker.patch("neo4j._sync.driver.Session", + autospec=True) + driver = GraphDatabase.driver("bolt://localhost") + with driver as driver: + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + _ = driver.session(bookmark_manager=bmm) + session_cls_mock.assert_called_once() + assert session_cls_mock.call_args[0][1].bookmark_manager is bmm + + +@mark_sync_test +def test_with_custom_inherited_sync_bookmark_manager(mocker) -> None: + class BMM(BookmarkManager): + def update_bookmarks( + self, database: str, previous_bookmarks: t.Iterable[str], + new_bookmarks: t.Iterable[str] + ) -> None: + ... + + def get_bookmarks(self, database: str) -> t.Collection[str]: + ... + + def get_all_bookmarks(self) -> t.Collection[str]: + ... + + def forget(self, databases: t.Iterable[str]) -> None: + ... + + bmm = BMM() + # could be one line, but want to make sure the type checker assigns + # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns + session_cls_mock = mocker.patch("neo4j._sync.driver.Session", + autospec=True) + driver = GraphDatabase.driver("bolt://localhost") + with driver as driver: + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + _ = driver.session(bookmark_manager=bmm) + session_cls_mock.assert_called_once() + assert session_cls_mock.call_args[0][1].bookmark_manager is bmm + + +@TestDecorators.mark_async_only_test +def test_with_custom_ducktype_async_bookmark_manager(mocker) -> None: + class BMM: + def update_bookmarks( + self, database: str, previous_bookmarks: t.Iterable[str], + new_bookmarks: t.Iterable[str] + ) -> None: + ... + + def get_bookmarks(self, database: str) -> t.Collection[str]: + ... + + def get_all_bookmarks(self) -> t.Collection[str]: + ... + + def forget(self, databases: t.Iterable[str]) -> None: + ... + + bmm = BMM() + # could be one line, but want to make sure the type checker assigns + # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns + session_cls_mock = mocker.patch("neo4j._sync.driver.Session", + autospec=True) + driver = GraphDatabase.driver("bolt://localhost") + with driver as driver: + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + _ = driver.session(bookmark_manager=bmm) + session_cls_mock.assert_called_once() + assert session_cls_mock.call_args[0][1].bookmark_manager is bmm + + +@mark_sync_test +def test_with_custom_ducktype_sync_bookmark_manager(mocker) -> None: + class BMM: + def update_bookmarks( + self, database: str, previous_bookmarks: t.Iterable[str], + new_bookmarks: t.Iterable[str] + ) -> None: + ... + + def get_bookmarks(self, database: str) -> t.Collection[str]: + ... + + def get_all_bookmarks(self) -> t.Collection[str]: + ... + + def forget(self, databases: t.Iterable[str]) -> None: + ... + + bmm = BMM() + # could be one line, but want to make sure the type checker assigns + # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns + session_cls_mock = mocker.patch("neo4j._sync.driver.Session", + autospec=True) + driver = GraphDatabase.driver("bolt://localhost") + with driver as driver: + with pytest.warns(ExperimentalWarning, match="bookmark_manager"): + _ = driver.session(bookmark_manager=bmm) + session_cls_mock.assert_called_once() + assert session_cls_mock.call_args[0][1].bookmark_manager is bmm diff --git a/tests/unit/sync/work/__init__.py b/tests/unit/sync/work/__init__.py index 27923502c..c42cc6fb6 100644 --- a/tests/unit/sync/work/__init__.py +++ b/tests/unit/sync/work/__init__.py @@ -14,9 +14,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -from ._fake_connection import ( - fake_connection, - fake_connection_generator, -) diff --git a/tests/unit/sync/work/_fake_connection.py b/tests/unit/sync/work/_fake_connection.py index 28132a4dc..d55006682 100644 --- a/tests/unit/sync/work/_fake_connection.py +++ b/tests/unit/sync/work/_fake_connection.py @@ -158,15 +158,15 @@ def __getattr__(self, name): parent = super() def build_message_handler(name): - try: - expected_message, scripted_callbacks = \ - self._script[self._script_pos] - except IndexError: - pytest.fail("End of scripted connection reached.") - assert name == expected_message - self._script_pos += 1 - def func(*args, **kwargs): + try: + expected_message, scripted_callbacks = \ + self._script[self._script_pos] + except IndexError: + pytest.fail("End of scripted connection reached.") + assert name == expected_message + self._script_pos += 1 + def callback(): for cb_name, default_cb_args in ( ("on_ignored", ({},)), @@ -176,8 +176,10 @@ def callback(): ("on_summary", ()), ): cb = kwargs.get(cb_name, None) - if (not callable(cb) - or cb_name not in scripted_callbacks): + if ( + not callable(cb) + or cb_name not in scripted_callbacks + ): continue cb_args = scripted_callbacks[cb_name] if cb_args is None: diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index dc5f71fc4..2c49b5ffb 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -28,26 +28,15 @@ unit_of_work, ) from neo4j._conf import SessionConfig -from neo4j._sync.io._pool import IOPool +from neo4j._sync.io import ( + BoltPool, + Neo4jPool, +) +from neo4j.api import BookmarkManager from ...._async_compat import mark_sync_test -@pytest.fixture() -def pool(fake_connection_generator, mocker): - pool = mocker.Mock(spec=IOPool) - assert not hasattr(pool, "acquired_connection_mocks") - pool.acquired_connection_mocks = [] - - def acquire_side_effect(*_, **__): - connection = fake_connection_generator() - pool.acquired_connection_mocks.append(connection) - return connection - - pool.acquire.side_effect = acquire_side_effect - return pool - - @mark_sync_test def test_session_context_calls_close(mocker): s = Session(None, SessionConfig()) @@ -66,9 +55,9 @@ def test_session_context_calls_close(mocker): )) @mark_sync_test def test_opens_connection_on_run( - pool, test_run_args, repetitions, consume + fake_pool, test_run_args, repetitions, consume ): - with Session(pool, SessionConfig()) as session: + with Session(fake_pool, SessionConfig()) as session: assert session._connection is None result = session.run(*test_run_args) assert session._connection is not None @@ -82,9 +71,9 @@ def test_opens_connection_on_run( @pytest.mark.parametrize("repetitions", range(1, 3)) @mark_sync_test def test_closes_connection_after_consume( - pool, test_run_args, repetitions + fake_pool, test_run_args, repetitions ): - with Session(pool, SessionConfig()) as session: + with Session(fake_pool, SessionConfig()) as session: result = session.run(*test_run_args) result.consume() assert session._connection is None @@ -96,9 +85,9 @@ def test_closes_connection_after_consume( )) @mark_sync_test def test_keeps_connection_until_last_result_consumed( - pool, test_run_args + fake_pool, test_run_args ): - with Session(pool, SessionConfig()) as session: + with Session(fake_pool, SessionConfig()) as session: result1 = session.run(*test_run_args) result2 = session.run(*test_run_args) assert session._connection is not None @@ -109,8 +98,8 @@ def test_keeps_connection_until_last_result_consumed( @mark_sync_test -def test_opens_connection_on_tx_begin(pool): - with Session(pool, SessionConfig()) as session: +def test_opens_connection_on_tx_begin(fake_pool): + with Session(fake_pool, SessionConfig()) as session: assert session._connection is None with session.begin_transaction() as _: assert session._connection is not None @@ -121,8 +110,10 @@ def test_opens_connection_on_tx_begin(pool): )) @pytest.mark.parametrize("repetitions", range(1, 3)) @mark_sync_test -def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions): - with Session(pool, SessionConfig()) as session: +def test_keeps_connection_on_tx_run( + fake_pool, test_run_args, repetitions +): + with Session(fake_pool, SessionConfig()) as session: with session.begin_transaction() as tx: for _ in range(repetitions): tx.run(*test_run_args) @@ -135,9 +126,9 @@ def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions): @pytest.mark.parametrize("repetitions", range(1, 3)) @mark_sync_test def test_keeps_connection_on_tx_consume( - pool, test_run_args, repetitions + fake_pool, test_run_args, repetitions ): - with Session(pool, SessionConfig()) as session: + with Session(fake_pool, SessionConfig()) as session: with session.begin_transaction() as tx: for _ in range(repetitions): result = tx.run(*test_run_args) @@ -149,8 +140,8 @@ def test_keeps_connection_on_tx_consume( ("RETURN $x", {"x": 1}), ("RETURN 1",) )) @mark_sync_test -def test_closes_connection_after_tx_close(pool, test_run_args): - with Session(pool, SessionConfig()) as session: +def test_closes_connection_after_tx_close(fake_pool, test_run_args): + with Session(fake_pool, SessionConfig()) as session: with session.begin_transaction() as tx: for _ in range(2): result = tx.run(*test_run_args) @@ -164,8 +155,8 @@ def test_closes_connection_after_tx_close(pool, test_run_args): ("RETURN $x", {"x": 1}), ("RETURN 1",) )) @mark_sync_test -def test_closes_connection_after_tx_commit(pool, test_run_args): - with Session(pool, SessionConfig()) as session: +def test_closes_connection_after_tx_commit(fake_pool, test_run_args): + with Session(fake_pool, SessionConfig()) as session: with session.begin_transaction() as tx: for _ in range(2): result = tx.run(*test_run_args) @@ -180,13 +171,13 @@ def test_closes_connection_after_tx_commit(pool, test_run_args): (None, [], ["abc"], ["foo", "bar"], {"a", "b"}, ("1", "two")) ) @mark_sync_test -def test_session_returns_bookmarks_directly(pool, bookmark_values): +def test_session_returns_bookmarks_directly(fake_pool, bookmark_values): if bookmark_values is not None: bookmarks = Bookmarks.from_raw_values(bookmark_values) else: bookmarks = Bookmarks() with Session( - pool, SessionConfig(bookmarks=bookmarks) + fake_pool, SessionConfig(bookmarks=bookmarks) ) as session: ret_bookmarks = (session.last_bookmarks()) assert isinstance(ret_bookmarks, Bookmarks) @@ -202,12 +193,13 @@ def test_session_returns_bookmarks_directly(pool, bookmark_values): (None, [], ["abc"], ["foo", "bar"], ("1", "two")) ) @mark_sync_test -def test_session_last_bookmark_is_deprecated(pool, bookmarks): +def test_session_last_bookmark_is_deprecated(fake_pool, bookmarks): if bookmarks is not None: with pytest.warns(DeprecationWarning): - session = Session(pool, SessionConfig(bookmarks=bookmarks)) + session = Session(fake_pool, + SessionConfig(bookmarks=bookmarks)) else: - session = Session(pool, SessionConfig(bookmarks=bookmarks)) + session = Session(fake_pool, SessionConfig(bookmarks=bookmarks)) with session: with pytest.warns(DeprecationWarning): if bookmarks: @@ -221,9 +213,11 @@ def test_session_last_bookmark_is_deprecated(pool, bookmarks): (("foo",), ("foo", "bar"), (), ["foo", "bar"], {"a", "b"}) ) @mark_sync_test -def test_session_bookmarks_as_iterable_is_deprecated(pool, bookmarks): +def test_session_bookmarks_as_iterable_is_deprecated( + fake_pool, bookmarks +): with pytest.warns(DeprecationWarning): - with Session(pool, SessionConfig( + with Session(fake_pool, SessionConfig( bookmarks=bookmarks )) as session: ret_bookmarks = (session.last_bookmarks()).raw_values @@ -237,19 +231,19 @@ def test_session_bookmarks_as_iterable_is_deprecated(pool, bookmarks): (["I don't", "think so"], TypeError), )) @mark_sync_test -def test_session_run_wrong_types(pool, query, error_type): - with Session(pool, SessionConfig()) as session: +def test_session_run_wrong_types(fake_pool, query, error_type): + with Session(fake_pool, SessionConfig()) as session: with pytest.raises(error_type): session.run(query) @pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction")) @mark_sync_test -def test_tx_function_argument_type(pool, tx_type): +def test_tx_function_argument_type(fake_pool, tx_type): def work(tx): assert isinstance(tx, ManagedTransaction) - with Session(pool, SessionConfig()) as session: + with Session(fake_pool, SessionConfig()) as session: getattr(session, tx_type)(work) @@ -262,18 +256,20 @@ def work(tx): )) @mark_sync_test -def test_decorated_tx_function_argument_type(pool, tx_type, decorator_kwargs): +def test_decorated_tx_function_argument_type( + fake_pool, tx_type, decorator_kwargs +): @unit_of_work(**decorator_kwargs) def work(tx): assert isinstance(tx, ManagedTransaction) - with Session(pool, SessionConfig()) as session: + with Session(fake_pool, SessionConfig()) as session: getattr(session, tx_type)(work) @mark_sync_test -def test_session_tx_type(pool): - with Session(pool, SessionConfig()) as session: +def test_session_tx_type(fake_pool): + with Session(fake_pool, SessionConfig()) as session: tx = session.begin_transaction() assert isinstance(tx, Transaction) @@ -300,9 +296,9 @@ def test_session_tx_type(pool): @pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed")) @mark_sync_test def test_session_run_with_parameters( - pool, parameters, run_type, mocker + fake_pool, parameters, run_type, mocker ): - with Session(pool, SessionConfig()) as session: + with Session(fake_pool, SessionConfig()) as session: if run_type == "auto": session.run("RETURN $x", **parameters) elif run_type == "unmanaged": @@ -315,9 +311,165 @@ def work(tx): else: raise ValueError(run_type) - assert len(pool.acquired_connection_mocks) == 1 - connection_mock = pool.acquired_connection_mocks[0] + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] assert connection_mock.run.called_once() call = connection_mock.run.call_args assert call.args[0] == "RETURN $x" assert call.kwargs["parameters"] == parameters + + +@pytest.mark.parametrize("db", (None, "adb")) +@pytest.mark.parametrize("routing", (True, False)) +# no home db resolution when connected to Neo4j 4.3 or earlier +@pytest.mark.parametrize("home_db_gets_resolved", (True, False)) +@pytest.mark.parametrize("additional_session_bookmarks", + (None, ["session", "bookmarks"])) +@mark_sync_test +def test_with_bookmark_manager( + fake_pool, db, routing, scripted_connection, home_db_gets_resolved, + additional_session_bookmarks, mocker +): + def update_routing_table_side_effect( + database, imp_user, bookmarks, acquisition_timeout=None, + database_callback=None + ): + if home_db_gets_resolved: + database_callback("homedb") + + def bmm_get_bookmarks(database): + return [f"{database}:bm1"] + + def bmm_gat_all_bookmarks(): + return ["all", "bookmarks"] + + scripted_connection.set_script([ + ("run", {"on_success": None, "on_summary": None}), + ("pull", { + "on_success": ({"bookmark": "res:bm1", "has_more": False},), + "on_summary": None, + "on_records": None, + }) + ]) + fake_pool.buffered_connection_mocks.append(scripted_connection) + + bmm = mocker.Mock(spec=BookmarkManager) + bmm.get_bookmarks.side_effect = bmm_get_bookmarks + bmm.get_all_bookmarks.side_effect = bmm_gat_all_bookmarks + + if routing: + fake_pool.mock_add_spec(Neo4jPool) + fake_pool.update_routing_table.side_effect = \ + update_routing_table_side_effect + else: + fake_pool.mock_add_spec(BoltPool) + + config = SessionConfig() + config.bookmark_manager = bmm + if db is not None: + config.database = db + if additional_session_bookmarks: + config.bookmarks = Bookmarks.from_raw_values( + additional_session_bookmarks + ) + with Session(fake_pool, config) as session: + assert not bmm.method_calls + + session.run("RETURN 1") + + # assert called bmm accordingly + expected_bmm_method_calls = [mocker.call.get_bookmarks("system"), + mocker.call.get_all_bookmarks()] + if routing and db is None: + expected_bmm_method_calls = [ + # extra call for resolving the home database + mocker.call.get_bookmarks("system"), + *expected_bmm_method_calls + ] + assert bmm.method_calls == expected_bmm_method_calls + assert (bmm.get_bookmarks.call_count + == len(expected_bmm_method_calls) - 1) + bmm.get_all_bookmarks.assert_called_once() + bmm.method_calls.clear() + + expected_update_for_db = db + if not db: + if home_db_gets_resolved and routing: + expected_update_for_db = "homedb" + else: + expected_update_for_db = "" + assert [call[0] for call in bmm.method_calls] == ["update_bookmarks"] + assert bmm.method_calls[0].kwargs == {} + assert len(bmm.method_calls[0].args) == 3 + assert bmm.method_calls[0].args[0] == expected_update_for_db + assert (set(bmm.method_calls[0].args[1]) + == {"all", "bookmarks", *(additional_session_bookmarks or [])}) + assert set(bmm.method_calls[0].args[2]) == {"res:bm1"} + + expected_pool_method_calls = ["acquire", "release"] + if routing and db is None: + expected_pool_method_calls = ["update_routing_table", + *expected_pool_method_calls] + assert ([call[0] for call in fake_pool.method_calls] + == expected_pool_method_calls) + assert (set(fake_pool.acquire.call_args.kwargs["bookmarks"]) + == {"system:bm1", *(additional_session_bookmarks or [])}) + if routing and db is None: + assert ( + set(fake_pool.update_routing_table.call_args.kwargs["bookmarks"]) + == {"system:bm1", *(additional_session_bookmarks or [])} + ) + + assert len(fake_pool.acquired_connection_mocks) == 1 + connection_mock = fake_pool.acquired_connection_mocks[0] + assert connection_mock.run.called_once() + connection_run_call_kwargs = connection_mock.run.call_args.kwargs + assert (set(connection_run_call_kwargs["bookmarks"]) + == {"all", "bookmarks", *(additional_session_bookmarks or [])}) + + +@pytest.mark.parametrize("routing", (True, False)) +@pytest.mark.parametrize("session_method", ("run", "get_server_info")) +@mark_sync_test +def test_last_bookmarks_do_not_leak_bookmark_managers_bookmarks( + fake_pool, routing, session_method, mocker +): + def bmm_get_bookmarks(database): + return [f"bmm:{database}"] + + def bmm_gat_all_bookmarks(): + return ["bmm:all", "bookmarks"] + + fake_pool.mock_add_spec(Neo4jPool if routing else BoltPool) + + bmm = mocker.Mock(spec=BookmarkManager) + bmm.get_bookmarks.side_effect = bmm_get_bookmarks + bmm.get_all_bookmarks.side_effect = bmm_gat_all_bookmarks + + config = SessionConfig() + config.bookmark_manager = bmm + config.bookmarks = Bookmarks.from_raw_values(["session", "bookmarks"]) + with Session(fake_pool, config) as session: + if session_method == "run": + session.run("RETURN 1") + elif session_method == "get_server_info": + session._get_server_info() + else: + assert False + last_bookmarks = session.last_bookmarks() + + assert last_bookmarks.raw_values == {"session", "bookmarks"} + assert last_bookmarks.raw_values == {"session", "bookmarks"} + + +@mark_sync_test +def test_with_ignored_bookmark_manager(fake_pool, mocker): + bmm = mocker.Mock(spec=BookmarkManager) + session_config = SessionConfig() + session_config.bookmark_manager = bmm + session_config.ignore_bookmark_manager = True + with Session(fake_pool, session_config) as session: + session.run("RETURN 1") + + bmm.assert_not_called() + assert not bmm.method_calls